Merge commit for internal changes

This commit is contained in:
Vijay Vasudevan 2017-05-04 21:31:53 -07:00
commit 15b8f3d65c
192 changed files with 8204 additions and 4109 deletions

2
configure vendored
View File

@ -385,7 +385,7 @@ fi
# Append CC optimization flags to bazel.rc # Append CC optimization flags to bazel.rc
for opt in $CC_OPT_FLAGS; do for opt in $CC_OPT_FLAGS; do
write_to_bazelrc 'build:opt --cxxopt=$opt --copt=$opt' write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt"
done done
# Run the gen_git_source to create links where bazel can track dependencies for # Run the gen_git_source to create links where bazel can track dependencies for

View File

@ -58,6 +58,7 @@ tf_cuda_library(
"//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:loader",
"//tensorflow/cc:gradients", "//tensorflow/cc:gradients",
"//tensorflow/cc:ops", "//tensorflow/cc:ops",
"//tensorflow/cc:grad_ops",
"//tensorflow/cc:scope_internal", "//tensorflow/cc:scope_internal",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",

View File

@ -91,6 +91,7 @@ cc_library(
deps = [ deps = [
":array_grad", ":array_grad",
":math_grad", ":math_grad",
":nn_grad",
], ],
) )

View File

@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceType& jit_device_name, const DeviceType& jit_device_name,
perftools::gputools::Platform* platform, perftools::gputools::Platform* platform,
Allocator* xla_allocator) Allocator* xla_allocator)
: LocalDevice(options, attrs, xla_allocator), : LocalDevice(options, attrs),
device_ordinal_(device_ordinal), device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name), jit_device_name_(jit_device_name),
xla_allocator_(xla_allocator), xla_allocator_(xla_allocator),

View File

@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
options, options,
Device::BuildDeviceAttributes( Device::BuildDeviceAttributes(
"", type, Bytes(256 << 20), DeviceLocality(), "", type, Bytes(256 << 20), DeviceLocality(),
strings::StrCat("device: XLA compilation device ", type.type())), strings::StrCat("device: XLA compilation device ", type.type()))),
cpu_allocator()),
allocator_(new XlaCompilationAllocator()) {} allocator_(new XlaCompilationAllocator()) {}
XlaCompilationDevice::~XlaCompilationDevice() {} XlaCompilationDevice::~XlaCompilationDevice() {}

View File

@ -668,6 +668,14 @@ class ComputationBuilder {
// then Build() should be used instead. // then Build() should be used instead.
Computation BuildAndNoteError(); Computation BuildAndNoteError();
// Returns the first error that was encountered while building the
// computation. When an error is encountered, by default we return a vacuous
// ComputationDataHandle and inform the user of the error that occurred while
// building the computation when they make a final call to Build().
//
// See also set_die_immediately_on_error().
Status first_error() const { return first_error_; }
private: private:
using PopulateLiteral = std::function<void(Literal*)>; using PopulateLiteral = std::function<void(Literal*)>;

View File

@ -201,7 +201,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name,
if (&argument == retval) { if (&argument == retval) {
continue; continue;
} }
compute_function_->setDoesNotAlias(argument.getArgNo() + 1); compute_function_->addAttribute(argument.getArgNo() + 1,
llvm::Attribute::NoAlias);
} }
ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(

View File

@ -196,7 +196,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
ir_emitter_context_->buffer_assignment().GetTempAllocation()) { ir_emitter_context_->buffer_assignment().GetTempAllocation()) {
kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size()); kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size());
} }
kernel->setDoesNotAlias(temp_buffer_arg_no + 1); kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias);
// Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
// treats it as a CUDA kernel. // treats it as a CUDA kernel.

View File

@ -705,7 +705,8 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
CHECK(ShapeUtil::IsArray(instruction->shape()) && CHECK(ShapeUtil::IsArray(instruction->shape()) &&
ShapeUtil::IsArray(operand->shape())); ShapeUtil::IsArray(operand->shape()));
if (instruction->IsElementwiseOnOperand(operand_no) && if ((instruction->IsElementwiseOnOperand(operand_no) ||
InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) &&
!ShapeUtil::IsScalar(operand->shape()) && !ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(operand->shape()) ==
ShapeUtil::Rank(instruction->shape())) { ShapeUtil::Rank(instruction->shape())) {

View File

@ -248,6 +248,15 @@ class LayoutAssignment : public HloPassInterface {
return Status::OK(); return Status::OK();
} }
// This method can be overriden to mark instructions as requiring the operands
// to have the same layout as the result, for performance or correctness. This
// will propagate constraints through the instruction from the result into the
// operands.
virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
const HloInstruction* instruction) {
return false;
}
// Construct contraints and assign layouts to all instructions in the // Construct contraints and assign layouts to all instructions in the
// computation satisfying the given ComputationLayout. Layouts constraints are // computation satisfying the given ComputationLayout. Layouts constraints are
// added, then propagated until all LogicalBuffers in the computation are // added, then propagated until all LogicalBuffers in the computation are

View File

@ -244,8 +244,11 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
} }
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
return InvalidArgument( return InvalidArgument(
"cannot concatenate arrays with different ranks: %lld vs %lld", "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape)); "(%s)",
ShapeUtil::Rank(*arg_shape),
ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
ShapeUtil::HumanString(*shape).c_str());
} }
if (arg_shape->element_type() != shape->element_type()) { if (arg_shape->element_type() != shape->element_type()) {
return InvalidArgument( return InvalidArgument(

View File

@ -118,6 +118,7 @@ set(tf_proto_text_srcs
"tensorflow/core/framework/types.proto" "tensorflow/core/framework/types.proto"
"tensorflow/core/framework/versions.proto" "tensorflow/core/framework/versions.proto"
"tensorflow/core/lib/core/error_codes.proto" "tensorflow/core/lib/core/error_codes.proto"
"tensorflow/core/protobuf/cluster.proto"
"tensorflow/core/protobuf/config.proto" "tensorflow/core/protobuf/config.proto"
"tensorflow/core/protobuf/debug.proto" "tensorflow/core/protobuf/debug.proto"
"tensorflow/core/protobuf/rewriter_config.proto" "tensorflow/core/protobuf/rewriter_config.proto"

View File

@ -22,6 +22,7 @@ set(tf_op_lib_names
"image_ops" "image_ops"
"io_ops" "io_ops"
"linalg_ops" "linalg_ops"
"lookup_ops"
"logging_ops" "logging_ops"
"math_ops" "math_ops"
"nn_ops" "nn_ops"

View File

@ -203,6 +203,7 @@ add_python_module("tensorflow/python/estimator")
add_python_module("tensorflow/python/estimator/export") add_python_module("tensorflow/python/estimator/export")
add_python_module("tensorflow/python/estimator/inputs") add_python_module("tensorflow/python/estimator/inputs")
add_python_module("tensorflow/python/estimator/inputs/queues") add_python_module("tensorflow/python/estimator/inputs/queues")
add_python_module("tensorflow/python/feature_column")
add_python_module("tensorflow/python/framework") add_python_module("tensorflow/python/framework")
add_python_module("tensorflow/python/grappler") add_python_module("tensorflow/python/grappler")
add_python_module("tensorflow/python/kernel_tests") add_python_module("tensorflow/python/kernel_tests")
@ -596,6 +597,7 @@ GENERATE_PYTHON_OP_LIB("image_ops")
GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("io_ops")
GENERATE_PYTHON_OP_LIB("linalg_ops") GENERATE_PYTHON_OP_LIB("linalg_ops")
GENERATE_PYTHON_OP_LIB("logging_ops") GENERATE_PYTHON_OP_LIB("logging_ops")
GENERATE_PYTHON_OP_LIB("lookup_ops")
GENERATE_PYTHON_OP_LIB("nn_ops") GENERATE_PYTHON_OP_LIB("nn_ops")
GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("parsing_ops")
GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("random_ops")

View File

@ -710,25 +710,6 @@ cuda_py_test(
], ],
) )
cuda_py_test(
name = "identity_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/identity_test.py"],
additional_deps = [
":bijectors_py",
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test( cuda_py_test(
name = "inline_test", name = "inline_test",
size = "small", size = "small",

View File

@ -25,6 +25,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.chi2 import * from tensorflow.contrib.distributions.python.ops.chi2 import *
from tensorflow.contrib.distributions.python.ops.conditional_distribution import *
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.deterministic import *
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
@ -44,12 +45,10 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import * from tensorflow.contrib.distributions.python.ops.sample_stats import *
from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.wishart import * from tensorflow.contrib.distributions.python.ops.wishart import *
from tensorflow.python.ops.distributions.bernoulli import * from tensorflow.python.ops.distributions.bernoulli import *
from tensorflow.python.ops.distributions.beta import * from tensorflow.python.ops.distributions.beta import *
from tensorflow.python.ops.distributions.categorical import * from tensorflow.python.ops.distributions.categorical import *
from tensorflow.python.ops.distributions.conditional_distribution import *
from tensorflow.python.ops.distributions.dirichlet import * from tensorflow.python.ops.distributions.dirichlet import *
from tensorflow.python.ops.distributions.dirichlet_multinomial import * from tensorflow.python.ops.distributions.dirichlet_multinomial import *
from tensorflow.python.ops.distributions.distribution import * from tensorflow.python.ops.distributions.distribution import *
@ -60,6 +59,7 @@ from tensorflow.python.ops.distributions.laplace import *
from tensorflow.python.ops.distributions.multinomial import * from tensorflow.python.ops.distributions.multinomial import *
from tensorflow.python.ops.distributions.normal import * from tensorflow.python.ops.distributions.normal import *
from tensorflow.python.ops.distributions.student_t import * from tensorflow.python.ops.distributions.student_t import *
from tensorflow.python.ops.distributions.transformed_distribution import *
from tensorflow.python.ops.distributions.uniform import * from tensorflow.python.ops.distributions.uniform import *
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member

View File

@ -23,9 +23,9 @@ import itertools
import numpy as np import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -20,12 +20,12 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import PowerTransform from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import PowerTransform
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -21,9 +21,9 @@ from __future__ import print_function
import numpy as np import numpy as np
from scipy import special from scipy import special
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test from tensorflow.python.platform import test
rng = np.random.RandomState(42) rng = np.random.RandomState(42)

View File

@ -43,7 +43,6 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
from tensorflow.contrib.distributions.python.ops.bijectors.exp import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
from tensorflow.contrib.distributions.python.ops.bijectors.identity import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
@ -52,6 +51,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered impo
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.bijector import *
from tensorflow.python.ops.distributions.identity_bijector import Identity
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member

View File

@ -1,29 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Identity bijector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.distributions.python.ops.bijectors.identity_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ["Identity"]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -17,9 +17,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.contrib.distributions.python.ops import conditional_distribution
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import conditional_distribution from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.contrib import linalg from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
@ -29,6 +28,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import logistic from tensorflow.contrib.distributions.python.ops import logistic
from tensorflow.contrib.distributions.python.ops import transformed_distribution
# Bijectors must be directly imported because `remove_undocumented` prevents # Bijectors must be directly imported because `remove_undocumented` prevents
# individual file imports. # individual file imports.
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
@ -27,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -30,6 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import student_t from tensorflow.python.ops.distributions import student_t
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -108,6 +108,7 @@ tf_custom_op_py_library(
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python:variable_scope", "//tensorflow/python:variable_scope",
"//tensorflow/python:variables", "//tensorflow/python:variables",
"//tensorflow/python/feature_column",
"@six_archive//:six", "@six_archive//:six",
], ],
) )

View File

@ -136,8 +136,10 @@ from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.ops import bucketization_op from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -1497,7 +1499,10 @@ def _real_valued_var_len_column(column_name,
is_sparse) is_sparse)
class _RealValuedColumn(_FeatureColumn, collections.namedtuple( class _RealValuedColumn(
_FeatureColumn,
fc_core._DenseColumn, # pylint: disable=protected-access
collections.namedtuple(
"_RealValuedColumn", "_RealValuedColumn",
["column_name", "dimension", "default_value", "dtype", "normalizer"])): ["column_name", "dimension", "default_value", "dtype", "normalizer"])):
"""Represents a real valued feature column also known as continuous features. """Represents a real valued feature column also known as continuous features.
@ -1569,6 +1574,23 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
def _to_dense_tensor(self, input_tensor): def _to_dense_tensor(self, input_tensor):
return input_tensor return input_tensor
@property
def _variable_shape(self):
return tensor_shape.TensorShape((self.dimension))
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
del weight_collections
del trainable
return inputs.get(self)
def _transform_feature(self, inputs):
return math_ops.to_float(
self._normalized_input_tensor(inputs.get(self.name)))
@property
def _parse_example_config(self):
return self.config
def real_valued_column(column_name, def real_valued_column(column_name,
dimension=1, dimension=1,

View File

@ -27,14 +27,15 @@ from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.core.example import example_pb2 from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2 from tensorflow.core.example import feature_pb2
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
@ -223,7 +224,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1) self.assertEqual(len(output), 1)
self.assertIn(keys_sparse, output) self.assertIn(keys_sparse, output)
with self.test_session(): with self.test_session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64) self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[keys_sparse].indices.eval(), self.assertAllEqual(output[keys_sparse].indices.eval(),
@ -241,7 +242,7 @@ class TransformerTest(test.TestCase):
output = feature_column_ops._Transformer(features).transform(keys_sparse) output = feature_column_ops._Transformer(features).transform(keys_sparse)
with self.test_session(): with self.test_session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
# While the input is a dense Tensor, the output should be a SparseTensor. # While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor) self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.dtype, dtypes.int64) self.assertEqual(output.dtype, dtypes.int64)
@ -310,7 +311,7 @@ class TransformerTest(test.TestCase):
self.assertIn(weighted_ids, output) self.assertIn(weighted_ids, output)
with self.test_session(): with self.test_session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(), self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
ids_tensor.dense_shape.eval()) ids_tensor.dense_shape.eval())
self.assertAllEqual(output[weighted_ids][0].indices.eval(), self.assertAllEqual(output[weighted_ids][0].indices.eval(),
@ -340,7 +341,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1) self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output) self.assertIn(vocab_sparse, output)
with self.test_session(): with self.test_session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[vocab_sparse].indices.eval(), self.assertAllEqual(output[vocab_sparse].indices.eval(),
@ -362,7 +363,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1) self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output) self.assertIn(vocab_sparse, output)
with self.test_session(): with self.test_session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
self.assertAllEqual(output[vocab_sparse].indices.eval(), self.assertAllEqual(output[vocab_sparse].indices.eval(),
@ -386,7 +387,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1) self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output) self.assertIn(vocab_sparse, output)
with self.test_session(): with self.test_session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[vocab_sparse].indices.eval(), self.assertAllEqual(output[vocab_sparse].indices.eval(),
@ -408,7 +409,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1) self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output) self.assertIn(vocab_sparse, output)
with self.test_session(): with self.test_session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
self.assertAllEqual(output[vocab_sparse].indices.eval(), self.assertAllEqual(output[vocab_sparse].indices.eval(),
@ -600,7 +601,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
one_hot_column, embedding_column, real_valued_column]) one_hot_column, embedding_column, real_valued_column])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10]) self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
def testRealValuedColumn(self): def testRealValuedColumn(self):
@ -610,6 +611,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued]) [real_valued])
with self.test_session(): with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval()) self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
fc_core.make_input_layer(features,
[real_valued]).eval())
def testRealValuedColumnWithMultiDimensions(self): def testRealValuedColumnWithMultiDimensions(self):
real_valued = feature_column.real_valued_column("price", 2) real_valued = feature_column.real_valued_column("price", 2)
@ -620,6 +625,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued]) [real_valued])
with self.test_session(): with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval()) self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
fc_core.make_input_layer(features,
[real_valued]).eval())
def testRealValuedColumnSparse(self): def testRealValuedColumnSparse(self):
sparse_real_valued = feature_column._real_valued_var_len_column( sparse_real_valued = feature_column._real_valued_var_len_column(
@ -640,6 +649,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued]) [real_valued])
with self.test_session(): with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2) self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
fc_core.make_input_layer(features,
[real_valued]).eval())
def testRealValuedColumnWithMultiDimensionsAndNormalizer(self): def testRealValuedColumnWithMultiDimensionsAndNormalizer(self):
real_valued = feature_column.real_valued_column( real_valued = feature_column.real_valued_column(
@ -651,6 +664,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued]) [real_valued])
with self.test_session(): with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2) self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
fc_core.make_input_layer(features,
[real_valued]).eval())
def testBucketizedColumnWithNormalizerSucceedsForDNN(self): def testBucketizedColumnWithNormalizerSucceedsForDNN(self):
bucket = feature_column.bucketized_column( bucket = feature_column.bucketized_column(
@ -697,7 +714,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_column]) [one_hot_column])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]], self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
output.eval()) output.eval())
@ -715,7 +732,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]], self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
output.eval()) output.eval())
@ -733,7 +750,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval()) output.eval())
@ -767,7 +784,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse]) [one_hot_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape) self.assertAllEqual([3, 10], output.eval().shape)
def testEmbeddingColumnSucceedsForDNN(self): def testEmbeddingColumnSucceedsForDNN(self):
@ -874,7 +891,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse]) [embeded_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10]) self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self): def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
@ -897,7 +914,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse]) [embeded_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10]) self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self): def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self):
@ -948,7 +965,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, ValueError,
"Error creating input layer for column: ids_weighted_by_weights"): "Error creating input layer for column: ids_weighted_by_weights"):
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
feature_column_ops.input_from_feature_columns(features, [weighted_ids]) feature_column_ops.input_from_feature_columns(features, [weighted_ids])
def testCrossedColumnFailsForDNN(self): def testCrossedColumnFailsForDNN(self):
@ -1055,7 +1072,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse]) [embeded_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
# score: (sum of weights) # score: (sum of weights)
self.assertAllEqual(output.eval(), [[10.], [50.], [0.]]) self.assertAllEqual(output.eval(), [[10.], [50.], [0.]])
@ -1293,7 +1310,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor) model_input = sess.run(model_input_tensor)
expected_input_shape = np.array([4, 3, 4]) expected_input_shape = np.array([4, 3, 4])
@ -1327,7 +1344,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor) model_input = sess.run(model_input_tensor)
expected_input_shape = np.array([4, 3, hash_buckets]) expected_input_shape = np.array([4, 3, hash_buckets])
@ -1357,7 +1374,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor) model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape) self.assertAllEqual(expected_input_shape, model_input.shape)
@ -1386,7 +1403,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor) model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape) self.assertAllEqual(expected_input_shape, model_input.shape)
@ -1416,7 +1433,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
embedding_weights) embedding_weights)
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
model_input, gradients = sess.run([model_input_tensor, gradient_tensor]) model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
expected_input_shape = [4, 3, embedding_dimension] expected_input_shape = [4, 3, embedding_dimension]
@ -1483,7 +1500,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor) model_input = sess.run(model_input_tensor)
expected_input_shape = [ expected_input_shape = [
@ -1564,7 +1581,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_ids], num_outputs=5) features, [weighted_ids], num_outputs=5)
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5]) self.assertAllEqual(logits.eval().shape, [2, 5])
def testWeightedSparseColumnWithDenseInputTensor(self): def testWeightedSparseColumnWithDenseInputTensor(self):
@ -1580,7 +1597,7 @@ class WeightedSumTest(test.TestCase):
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5]) self.assertAllEqual(logits.eval().shape, [2, 5])
def testCrossedColumn(self): def testCrossedColumn(self):
@ -1634,7 +1651,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1)) features, [movies], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.initialize_all_variables().run() variables_lib.initialize_all_variables().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[movies][0] weights = column_to_variable[movies][0]
self.assertEqual(weights.get_shape(), (3, 1)) self.assertEqual(weights.get_shape(), (3, 1))
@ -1709,7 +1726,7 @@ class WeightedSumTest(test.TestCase):
features, [age, language], num_outputs=1)) features, [age, language], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]]) self.assertAllClose(output.eval(), [[0.], [0.]])
@ -1749,7 +1766,7 @@ class WeightedSumTest(test.TestCase):
self.assertEqual(len(variables), 1) self.assertEqual(len(variables), 1)
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]]) self.assertAllClose(output.eval(), [[0.], [0.]])
@ -1813,7 +1830,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_language], num_outputs=1)) features, [weighted_language], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]]) self.assertAllClose(output.eval(), [[0.], [0.]])
@ -1841,7 +1858,7 @@ class WeightedSumTest(test.TestCase):
features, [language], num_outputs=1)) features, [language], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
# score: 0.1 + language_weight['hindi'] + language_weight['english'] # score: 0.1 + language_weight['hindi'] + language_weight['english']
sess.run(bias.assign([0.1])) sess.run(bias.assign([0.1]))
@ -1864,7 +1881,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1)) features, [movies], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[movies][0] weights = column_to_variable[movies][0]
self.assertEqual(weights.get_shape(), (15, 1)) self.assertEqual(weights.get_shape(), (15, 1))
@ -1898,7 +1915,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language], num_outputs=1)) features, [country_language], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language][0] weights = column_to_variable[country_language][0]
sess.run(weights.assign(weights + 0.4)) sess.run(weights.assign(weights + 0.4))
@ -1922,7 +1939,7 @@ class WeightedSumTest(test.TestCase):
features, [language_language], num_outputs=1)) features, [language_language], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[language_language][0] weights = column_to_variable[language_language][0]
sess.run(weights.assign(weights + 0.4)) sess.run(weights.assign(weights + 0.4))
@ -1955,7 +1972,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language], num_outputs=1)) features, [country_language], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language][0] weights = column_to_variable[country_language][0]
sess.run(weights.assign(weights + 0.4)) sess.run(weights.assign(weights + 0.4))
@ -1996,7 +2013,7 @@ class WeightedSumTest(test.TestCase):
scope=scope)) scope=scope))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertEqual(2, len(column_to_variable[country])) self.assertEqual(2, len(column_to_variable[country]))
self.assertEqual(3, len(column_to_variable[language])) self.assertEqual(3, len(column_to_variable[language]))
@ -2033,7 +2050,7 @@ class WeightedSumTest(test.TestCase):
features, [country, age, incomes], num_outputs=1)) features, [country, age, incomes], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
incomes_weights = column_to_variable[incomes][0] incomes_weights = column_to_variable[incomes][0]
sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]])) sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]]))
@ -2069,7 +2086,7 @@ class WeightedSumTest(test.TestCase):
features, [country, age, height, incomes], num_outputs=5)) features, [country, age, height, incomes], num_outputs=5))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
height_weights = column_to_variable[height][0] height_weights = column_to_variable[height][0]
sess.run( sess.run(
@ -2099,7 +2116,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket], num_outputs=1)) features, [bucket], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3], sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3],
[0.4]])) [0.4]]))
@ -2127,7 +2144,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket, country], num_outputs=1)) features, [bucket, country], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
# dimension = 2, bucket_size = 4, num_classes = 1 # dimension = 2, bucket_size = 4, num_classes = 1
sess.run(column_to_variable[bucket][0].assign( sess.run(column_to_variable[bucket][0].assign(
@ -2156,7 +2173,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket, country], num_outputs=5)) features, [bucket, country], num_outputs=5))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
# dimension = 2, bucket_size = 4, num_classes = 5 # dimension = 2, bucket_size = 4, num_classes = 5
sess.run(column_to_variable[bucket][0].assign( sess.run(column_to_variable[bucket][0].assign(
@ -2192,7 +2209,7 @@ class WeightedSumTest(test.TestCase):
features, [country_price], num_outputs=1)) features, [country_price], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[country_price][0] weights = column_to_variable[country_price][0]
sess.run(weights.assign(weights + 0.4)) sess.run(weights.assign(weights + 0.4))
@ -2231,7 +2248,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language_price], num_outputs=1)) features, [country_language_price], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language_price][0] weights = column_to_variable[country_language_price][0]
sess.run(weights.assign(weights + 0.4)) sess.run(weights.assign(weights + 0.4))
@ -2255,7 +2272,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1)) features, [product], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0] product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@ -2270,7 +2287,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1)) features, [product], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0] product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@ -2285,7 +2302,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1)) features, [product], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0] product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.6], [0.7]]) self.assertAllClose(output.eval(), [[0.6], [0.7]])
@ -2306,7 +2323,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1)) features, [product], num_outputs=1))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0] product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@ -2318,7 +2335,7 @@ class WeightedSumTest(test.TestCase):
features, [feature_column.real_valued_column("age")], num_outputs=3) features, [feature_column.real_valued_column("age")], num_outputs=3)
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
sess.run(bias.assign([0.1, 0.2, 0.3])) sess.run(bias.assign([0.1, 0.2, 0.3]))
self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3], self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3],
[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) [0.1, 0.2, 0.3], [0.1, 0.2, 0.3]])
@ -2332,7 +2349,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3)) features, [column], num_outputs=3))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0] weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (1, 3)) self.assertEqual(weights.get_shape(), (1, 3))
sess.run(weights.assign([[0.01, 0.03, 0.05]])) sess.run(weights.assign([[0.01, 0.03, 0.05]]))
@ -2356,7 +2373,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3)) features, [column], num_outputs=3))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0] weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3)) self.assertEqual(weights.get_shape(), (5, 3))
sess.run( sess.run(
@ -2382,7 +2399,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3)) features, [column], num_outputs=3))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0] weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3)) self.assertEqual(weights.get_shape(), (5, 3))
@ -2422,7 +2439,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3)) features, [column], num_outputs=3))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0] weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3)) self.assertEqual(weights.get_shape(), (5, 3))
@ -2451,7 +2468,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3)) features, [column], num_outputs=3))
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0] weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3)) self.assertEqual(weights.get_shape(), (5, 3))
@ -2516,7 +2533,7 @@ class ParseExampleTest(test.TestCase):
self.assertIn(bucket, output) self.assertIn(bucket, output)
self.assertIn(wire_cast, output) self.assertIn(wire_cast, output)
with self.test_session(): with self.test_session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]]) self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]]) self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
self.assertAllEqual(output[wire_cast].values.eval(), [2, 0]) self.assertAllEqual(output[wire_cast].values.eval(), [2, 0])

View File

@ -46,7 +46,7 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
Args: Args:
uniform: Whether to use uniform or normal distributed random initialization. uniform: Whether to use uniform or normal distributed random initialization.
seed: A Python integer. Used to create random seeds. See seed: A Python integer. Used to create random seeds. See
@{set_random_seed} for behavior. @{tf.set_random_seed} for behavior.
dtype: The data type. Only floating point types are supported. dtype: The data type. Only floating point types are supported.
Returns: Returns:
@ -96,7 +96,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'. mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'.
uniform: Whether to use uniform or normal distributed random initialization. uniform: Whether to use uniform or normal distributed random initialization.
seed: A Python integer. Used to create random seeds. See seed: A Python integer. Used to create random seeds. See
@{set_random_seed} for behavior. @{tf.set_random_seed} for behavior.
dtype: The data type. Only floating point types are supported. dtype: The data type. Only floating point types are supported.
Returns: Returns:

View File

@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -157,7 +157,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
self.context_feature_columns) self.context_feature_columns)
with self.test_session() as sess: with self.test_session() as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.tables_initializer()) sess.run(lookup_ops.tables_initializer())
sequence_input_val = sess.run(sequence_input) sequence_input_val = sess.run(sequence_input)
expected_shape = np.array([ expected_shape = np.array([
3, # expected batch size 3, # expected batch size
@ -178,7 +178,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
# Obtain values of activations and final state. # Obtain values of activations and final state.
with session.Session() as sess: with session.Session() as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.tables_initializer()) sess.run(lookup_ops.tables_initializer())
activations, final_state = sess.run([activations_t, final_state_t]) activations, final_state = sess.run([activations_t, final_state_t])
expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS]) expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])

View File

@ -57,7 +57,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources from tensorflow.python.ops import resources
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
@ -1292,7 +1292,7 @@ class Estimator(BaseEstimator):
init_op = control_flow_ops.group( init_op = control_flow_ops.group(
variables.local_variables_initializer(), variables.local_variables_initializer(),
resources.initialize_resources(resources.shared_resources()), resources.initialize_resources(resources.shared_resources()),
data_flow_ops.tables_initializer()) lookup_ops.tables_initializer())
# Perform the export # Perform the export
builder = saved_model_builder.SavedModelBuilder(export_dir) builder = saved_model_builder.SavedModelBuilder(export_dir)

View File

@ -32,7 +32,7 @@ from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses as losses_lib from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -1214,7 +1214,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn, train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1.),)) logits=((1., 0., 0.), (0., 0., 1.),))
with session.Session(): with session.Session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual( self.assertAllEqual(
[0, 2], [0, 2],
model_fn_ops.predictions["classes"].eval()) model_fn_ops.predictions["classes"].eval())
@ -1266,7 +1266,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn, train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1.),)) logits=((1., 0., 0.), (0., 0., 1.),))
with session.Session(): with session.Session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual( self.assertAllEqual(
[b"key0", b"key2"], [b"key0", b"key2"],
model_fn_ops.predictions["classes"].eval()) model_fn_ops.predictions["classes"].eval())
@ -1301,7 +1301,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn, train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.),)) logits=((1., 0., 0.),))
with session.Session(): with session.Session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op) self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self) _assert_no_variables(self)
_assert_summary_tags(self, ["loss"]) _assert_summary_tags(self, ["loss"])
@ -1327,7 +1327,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn, train_op_fn=head_lib.no_op_train_fn,
logits=((0., 0., 1.),)) logits=((0., 0., 1.),))
with session.Session(): with session.Session():
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op) self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self) _assert_no_variables(self)
_assert_summary_tags(self, ["loss"]) _assert_summary_tags(self, ["loss"])

View File

@ -35,8 +35,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -55,7 +55,7 @@ class PrepareInputsForRnnTest(test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.initialize_all_tables()) sess.run(lookup_ops.tables_initializer())
features_val = sess.run(features_by_time) features_val = sess.run(features_by_time)
self.assertAllEqual(expected, features_val) self.assertAllEqual(expected, features_val)
@ -316,7 +316,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.initialize_all_tables()) sess.run(lookup_ops.tables_initializer())
actual_sequence, actual_context = sess.run( actual_sequence, actual_context = sess.run(
[sequence, context]) [sequence, context])
assert_equal(expected_sequence, actual_sequence) assert_equal(expected_sequence, actual_sequence)

View File

@ -647,6 +647,10 @@ class Experiment(object):
if _sentinel is not None: if _sentinel is not None:
raise ValueError("_call_train should be called with keyword args only") raise ValueError("_call_train should be called with keyword args only")
# Estimator in core cannot work with monitors. We need to convert them
# to hooks. For Estimator in contrib, it is converted internally. So, it is
# safe to convert for both cases.
hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator)
if self._core_estimator_used: if self._core_estimator_used:
return self._estimator.train(input_fn=input_fn, return self._estimator.train(input_fn=input_fn,
steps=steps, steps=steps,

View File

@ -24,7 +24,6 @@ import time
from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn import run_config from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
@ -461,7 +460,8 @@ class ExperimentTest(test.TestCase):
self.assertEqual(1, est.eval_count) self.assertEqual(1, est.eval_count)
self.assertEqual(1, len(est.monitors)) self.assertEqual(1, len(est.monitors))
self.assertEqual([noop_hook], est.eval_hooks) self.assertEqual([noop_hook], est.eval_hooks)
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor)) self.assertTrue(isinstance(est.monitors[0],
session_run_hook.SessionRunHook))
def test_train_hooks_extend_does_not_mutate_input_hooks(self): def test_train_hooks_extend_does_not_mutate_input_hooks(self):
for est in self._estimators_for_tests(): for est in self._estimators_for_tests():
@ -563,7 +563,8 @@ class ExperimentTest(test.TestCase):
self.assertEqual(1, est.export_count) self.assertEqual(1, est.export_count)
self.assertEqual(1, len(est.monitors)) self.assertEqual(1, len(est.monitors))
self.assertEqual([noop_hook], est.eval_hooks) self.assertEqual([noop_hook], est.eval_hooks)
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor)) self.assertTrue(isinstance(est.monitors[0],
session_run_hook.SessionRunHook))
def test_train_and_evaluate_with_no_eval_during_training(self): def test_train_and_evaluate_with_no_eval_during_training(self):
for est in self._estimators_for_tests(): for est in self._estimators_for_tests():

View File

@ -37,8 +37,8 @@ from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources from tensorflow.python.ops import resources
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -429,11 +429,14 @@ def _get_ready_op():
def _get_local_init_op(): def _get_local_init_op():
"""Returns the local init ops to initialize tables and local variables."""
local_init_op = _get_first_op_from_collection( local_init_op = _get_first_op_from_collection(
ops.GraphKeys.LOCAL_INIT_OP) ops.GraphKeys.LOCAL_INIT_OP)
if local_init_op is None: if local_init_op is None:
op_list = [variables.local_variables_initializer(), op_list = [
data_flow_ops.tables_initializer()] variables.local_variables_initializer(),
lookup_ops.tables_initializer()
]
if op_list: if op_list:
local_init_op = control_flow_ops.group(*op_list) local_init_op = control_flow_ops.group(*op_list)
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
@ -680,7 +683,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
else: else:
session.run(variables.global_variables_initializer()) session.run(variables.global_variables_initializer())
session.run(variables.local_variables_initializer()) session.run(variables.local_variables_initializer())
session.run(data_flow_ops.tables_initializer()) session.run(lookup_ops.tables_initializer())
coord = coordinator.Coordinator() coord = coordinator.Coordinator()
threads = None threads = None
try: try:

View File

@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import saver as tf_saver
@ -67,17 +67,17 @@ def _export_graph(graph, saver, checkpoint_path, export_dir,
with graph.as_default(): with graph.as_default():
with tf_session.Session('') as session: with tf_session.Session('') as session:
variables.local_variables_initializer() variables.local_variables_initializer()
data_flow_ops.tables_initializer() lookup_ops.tables_initializer()
saver.restore(session, checkpoint_path) saver.restore(session, checkpoint_path)
export = exporter.Exporter(saver) export = exporter.Exporter(saver)
export.init(init_op=control_flow_ops.group( export.init(
init_op=control_flow_ops.group(
variables.local_variables_initializer(), variables.local_variables_initializer(),
data_flow_ops.tables_initializer()), lookup_ops.tables_initializer()),
default_graph_signature=default_graph_signature, default_graph_signature=default_graph_signature,
named_graph_signatures=named_graph_signatures, named_graph_signatures=named_graph_signatures,
assets_collection=ops.get_collection( assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))
ops.GraphKeys.ASSET_FILEPATHS))
return export.export(export_dir, contrib_variables.get_global_step(), return export.export(export_dir, contrib_variables.get_global_step(),
session, exports_to_keep=exports_to_keep) session, exports_to_keep=exports_to_keep)

View File

@ -13,19 +13,10 @@ py_library(
name = "lookup_py", name = "lookup_py",
srcs = [ srcs = [
"__init__.py", "__init__.py",
"lookup_ops.py",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
"//tensorflow/python:array_ops", "//tensorflow/python/feature_column:lookup_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops_gen",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
], ],
) )
@ -39,11 +30,11 @@ py_test(
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client", "//tensorflow/python:client",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:variables", "//tensorflow/python:variables",

View File

@ -47,7 +47,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
# pylint: disable=unused-import,wildcard-import # pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.lookup.lookup_ops import * from tensorflow.python.feature_column.lookup_ops import *
# pylint: enable=unused-import,wildcard-import # pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.util.all_util import remove_undocumented

View File

@ -31,7 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import saver from tensorflow.python.training import saver
@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase):
table3 = lookup.HashTable( table3 = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), default_val) lookup.KeyValueTensorInitializer(keys, values), default_val)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(3, table1.size().eval()) self.assertAllEqual(3, table1.size().eval())
self.assertAllEqual(3, table2.size().eval()) self.assertAllEqual(3, table2.size().eval())
self.assertAllEqual(3, table3.size().eval()) self.assertAllEqual(3, table3.size().eval())
@ -1184,7 +1184,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval()) self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_file(self): def test_int32_index_table_from_file(self):
@ -1198,7 +1198,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval()) self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_file(self): def test_int64_index_table_from_file(self):
@ -1212,7 +1212,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval()) self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_file_with_default_value(self): def test_index_table_from_file_with_default_value(self):
@ -1224,7 +1224,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval()) self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_file_with_oov_buckets(self): def test_index_table_from_file_with_oov_buckets(self):
@ -1236,7 +1236,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual( self.assertAllEqual(
( (
1, # From vocabulary file. 1, # From vocabulary file.
@ -1259,7 +1259,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, -1, -1), ids.eval()) self.assertAllEqual((1, -1, -1), ids.eval())
self.assertEqual(2, table.size().eval()) self.assertEqual(2, table.size().eval())
@ -1286,7 +1286,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), ids.eval()) self.assertAllEqual((1, 2, -1), ids.eval())
self.assertEqual(3, table.size().eval()) self.assertEqual(3, table.size().eval())
@ -1345,7 +1345,7 @@ class IndexTableFromTensor(test.TestCase):
ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval()) self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_tensor_with_tensor_init(self): def test_int32_index_table_from_tensor_with_tensor_init(self):
@ -1356,7 +1356,7 @@ class IndexTableFromTensor(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval()) self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self): def test_int64_index_table_from_tensor_with_tensor_init(self):
@ -1367,7 +1367,7 @@ class IndexTableFromTensor(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval()) self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_tensor_with_default_value(self): def test_index_table_from_tensor_with_default_value(self):
@ -1378,7 +1378,7 @@ class IndexTableFromTensor(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval()) self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_mapping(self): def test_index_table_from_tensor_missing_mapping(self):
@ -1394,7 +1394,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertRaises(errors_impl.OpError, ids.eval) self.assertRaises(errors_impl.OpError, ids.eval)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
errors_impl.OpError, "keys and values cannot be empty"): errors_impl.OpError, "keys and values cannot be empty"):
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self): def test_index_table_from_tensor_with_invalid_hashers(self):
with self.test_session(): with self.test_session():
@ -1422,7 +1422,7 @@ class StringToIndexTest(test.TestCase):
indices = lookup.string_to_index(feats, mapping=mapping_strings) indices = lookup.string_to_index(feats, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, indices.eval) self.assertRaises(errors_impl.OpError, indices.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), indices.eval()) self.assertAllEqual((1, 2, -1), indices.eval())
@ -1433,7 +1433,7 @@ class StringToIndexTest(test.TestCase):
_ = lookup.string_to_index(feats, mapping=mapping_strings) _ = lookup.string_to_index(feats, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, self.assertRaises(errors_impl.OpError,
data_flow_ops.tables_initializer().run) lookup_ops.tables_initializer().run)
def test_string_to_index_with_default_value(self): def test_string_to_index_with_default_value(self):
default_value = -42 default_value = -42
@ -1444,7 +1444,7 @@ class StringToIndexTest(test.TestCase):
feats, mapping=mapping_strings, default_value=default_value) feats, mapping=mapping_strings, default_value=default_value)
self.assertRaises(errors_impl.OpError, indices.eval) self.assertRaises(errors_impl.OpError, indices.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), indices.eval()) self.assertAllEqual((1, 2, default_value), indices.eval())
@ -1463,7 +1463,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
vocabulary_file=vocabulary_file) vocabulary_file=vocabulary_file)
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval) self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
features.eval()) features.eval())
@ -1475,7 +1475,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
vocabulary_file=vocabulary_file, default_value=default_value) vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval) self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value), self.assertAllEqual((b"salad", b"surgery", default_value),
features.eval()) features.eval())
@ -1489,7 +1489,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
default_value=default_value) default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval) self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", default_value, default_value), self.assertAllEqual((b"salad", default_value, default_value),
features.eval()) features.eval())
@ -1501,7 +1501,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval) self.assertRaises(errors_impl.OpError, features.eval)
init = data_flow_ops.tables_initializer() init = lookup_ops.tables_initializer()
self.assertRaisesRegexp(errors_impl.InvalidArgumentError, self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Invalid vocab_size", init.run) "Invalid vocab_size", init.run)
@ -1513,7 +1513,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval) self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval()) self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
@ -1528,7 +1528,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
features = table.lookup(indices) features = table.lookup(indices)
self.assertRaises(errors_impl.OpError, features.eval) self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
features.eval()) features.eval())
@ -1540,7 +1540,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
mapping=mapping_strings) mapping=mapping_strings)
indices = constant_op.constant([0, 1, 4], dtypes.int64) indices = constant_op.constant([0, 1, 4], dtypes.int64)
features = table.lookup(indices) features = table.lookup(indices)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval()) self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())
def test_index_to_string_with_default_value(self): def test_index_to_string_with_default_value(self):
@ -1553,7 +1553,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features = table.lookup(indices) features = table.lookup(indices)
self.assertRaises(errors_impl.OpError, features.eval) self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value), self.assertAllEqual((b"salad", b"surgery", default_value),
features.eval()) features.eval())
@ -1567,7 +1567,7 @@ class IndexToStringTest(test.TestCase):
feats = lookup.index_to_string(indices, mapping=mapping_strings) feats = lookup.index_to_string(indices, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, feats.eval) self.assertRaises(errors_impl.OpError, feats.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
feats.eval()) feats.eval())
@ -1577,11 +1577,11 @@ class IndexToStringTest(test.TestCase):
mapping_strings = constant_op.constant(["hello", "hello"]) mapping_strings = constant_op.constant(["hello", "hello"])
indices = constant_op.constant([0, 1, 4], dtypes.int64) indices = constant_op.constant([0, 1, 4], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings) feats = lookup.index_to_string(indices, mapping=mapping_strings)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval()) self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())
self.assertRaises(errors_impl.OpError, self.assertRaises(errors_impl.OpError,
data_flow_ops.tables_initializer().run) lookup_ops.tables_initializer().run)
def test_index_to_string_with_default_value(self): def test_index_to_string_with_default_value(self):
default_value = b"NONE" default_value = b"NONE"
@ -1592,7 +1592,7 @@ class IndexToStringTest(test.TestCase):
indices, mapping=mapping_strings, default_value=default_value) indices, mapping=mapping_strings, default_value=default_value)
self.assertRaises(errors_impl.OpError, feats.eval) self.assertRaises(errors_impl.OpError, feats.eval)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval()) self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
@ -1755,7 +1755,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value, default_value,
shared_name=shared_name) shared_name=shared_name)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
input_string = constant_op.constant(["brain", "salad", "tank"]) input_string = constant_op.constant(["brain", "salad", "tank"])
@ -2081,7 +2081,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup.StrongHashSpec((1, 2)), hasher_spec=lookup.StrongHashSpec((1, 2)),
name="table2") name="table2")
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
input_string = constant_op.constant( input_string = constant_op.constant(
["fruit", "brain", "salad", "surgery", "UNK"]) ["fruit", "brain", "salad", "surgery", "UNK"])
@ -2167,7 +2167,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
default_value2), default_value2),
oov_buckets) oov_buckets)
data_flow_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
input_string_1 = constant_op.constant( input_string_1 = constant_op.constant(
["brain", "salad", "surgery", "UNK"]) ["brain", "salad", "surgery", "UNK"])

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc
tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc
tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc
tensorflow/core/protobuf/meta_graph.pb.cc tensorflow/core/protobuf/meta_graph.pb.cc
tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/config.pb.cc
tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/debug.pb.cc tensorflow/core/protobuf/debug.pb.cc

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/meta_graph.pb.h tensorflow/core/protobuf/meta_graph.pb.h
tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h tensorflow/core/protobuf/rewriter_config.pb.h

View File

@ -1,6 +1,7 @@
tensorflow/core/util/saved_tensor_slice.pb_text.cc tensorflow/core/util/saved_tensor_slice.pb_text.cc
tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc
tensorflow/core/protobuf/saver.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/cluster.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/rewriter_config.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/meta_graph.proto tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/rewriter_config.proto tensorflow/core/protobuf/rewriter_config.proto

View File

@ -1338,6 +1338,87 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
name=name_scope) name=name_scope)
def sparse_recall_at_top_k(labels,
top_k_predictions,
class_id=None,
weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
"""Computes recall@k of top-k predictions with respect to sparse labels.
If `class_id` is specified, we calculate recall by considering only the
entries in the batch for which `class_id` is in the label, and computing
the fraction of them for which `class_id` is in the top-k `predictions`.
If `class_id` is not specified, we'll calculate recall as how often on
average a class among the labels of a batch entry is in the top-k
`predictions`.
`sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
and `false_negative_at_<k>`, that are used to compute the recall_at_k
frequency. This frequency is ultimately returned as `recall_at_<k>`: an
idempotent operation that simply divides `true_positive_at_<k>` by total
(`true_positive_at_<k>` + `false_negative_at_<k>`).
For estimation of the metric over a stream of data, the function creates an
`update_op` operation that updates these variables and returns the
`recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
true positives and false negatives weighted by `weights`. Then `update_op`
increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
values.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Args:
labels: `int64` `Tensor` or `SparseTensor` with shape
[D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
target classes for the associated prediction. Commonly, N=1 and `labels`
has shape [batch_size, num_labels]. [D1, ... DN] must match
`top_k_predictions`. Values should be in range [0, num_classes), where
num_classes is the last dimension of `predictions`. Values outside this
range always count towards `false_negative_at_<k>`.
top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
The final dimension contains the indices of top-k labels. [D1, ... DN]
must match `labels`.
class_id: Integer class ID for which we want binary metrics. This should be
in range [0, num_classes), where num_classes is the last dimension of
`predictions`. If class_id is outside this range, the method returns NAN.
weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
`labels`. If the latter, it must be broadcastable to `labels` (i.e., all
dimensions must be either `1`, or the same as the corresponding `labels`
dimension).
metrics_collections: An optional list of collections that values should
be added to.
updates_collections: An optional list of collections that updates should
be added to.
name: Name of new update operation, and namespace for other dependent ops.
Returns:
recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
by the sum of `true_positives` and `false_negatives`.
update_op: `Operation` that increments `true_positives` and
`false_negatives` variables appropriately, and whose value matches
`recall`.
Raises:
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
default_name = _at_k_name('recall', class_id=class_id)
with ops.name_scope(name, default_name, (top_k_predictions, labels,
weights)) as name_scope:
return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access
labels=labels,
predictions_idx=top_k_predictions,
class_id=class_id,
weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections,
name=name_scope)
def streaming_sparse_average_precision_at_k(predictions, def streaming_sparse_average_precision_at_k(predictions,
labels, labels,
k, k,
@ -2288,6 +2369,7 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
__all__ = [ __all__ = [
'aggregate_metric_map', 'aggregate_metric_map',
'aggregate_metrics', 'aggregate_metrics',
'sparse_recall_at_top_k',
'streaming_accuracy', 'streaming_accuracy',
'streaming_auc', 'streaming_auc',
'streaming_false_negatives', 'streaming_false_negatives',
@ -2310,7 +2392,9 @@ __all__ = [
'streaming_root_mean_squared_error', 'streaming_root_mean_squared_error',
'streaming_sensitivity_at_specificity', 'streaming_sensitivity_at_specificity',
'streaming_sparse_average_precision_at_k', 'streaming_sparse_average_precision_at_k',
'streaming_sparse_average_precision_at_top_k',
'streaming_sparse_precision_at_k', 'streaming_sparse_precision_at_k',
'streaming_sparse_precision_at_top_k',
'streaming_sparse_recall_at_k', 'streaming_sparse_recall_at_k',
'streaming_specificity_at_sensitivity', 'streaming_specificity_at_sensitivity',
'streaming_true_negatives', 'streaming_true_negatives',

View File

@ -2958,8 +2958,38 @@ class StreamingSparseRecallTest(test.TestCase):
self.assertEqual(expected, update.eval()) self.assertEqual(expected, update.eval())
self.assertEqual(expected, metric.eval()) self.assertEqual(expected, metric.eval())
def _test_sparse_recall_at_top_k(self,
labels,
top_k_predictions,
expected,
class_id=None,
weights=None):
with ops.Graph().as_default() as g, self.test_session(g):
if weights is not None:
weights = constant_op.constant(weights, dtypes_lib.float32)
metric, update = metric_ops.sparse_recall_at_top_k(
labels=labels,
top_k_predictions=constant_op.constant(top_k_predictions,
dtypes_lib.int32),
class_id=class_id,
weights=weights)
# Fails without initialized vars.
self.assertRaises(errors_impl.OpError, metric.eval)
self.assertRaises(errors_impl.OpError, update.eval)
variables.variables_initializer(variables.local_variables()).run()
# Run per-step op and assert expected values.
if math.isnan(expected):
self.assertTrue(math.isnan(update.eval()))
self.assertTrue(math.isnan(metric.eval()))
else:
self.assertEqual(expected, update.eval())
self.assertEqual(expected, metric.eval())
def test_one_label_at_k1_nan(self): def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value( sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]]) [[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64) dense_labels = np.array([[3], [2]], dtype=np.int64)
@ -2970,9 +3000,12 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (-1, 0, 1, 4): for class_id in (-1, 0, 1, 4):
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, class_id=class_id) predictions, labels, k=1, expected=NAN, class_id=class_id)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_one_label_at_k1_no_predictions(self): def test_one_label_at_k1_no_predictions(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value( sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]]) [[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64) dense_labels = np.array([[3], [2]], dtype=np.int64)
@ -2981,9 +3014,12 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 0 predictions. # Class 2: 0 predictions.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.0, class_id=2) predictions, labels, k=1, expected=0.0, class_id=2)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0, class_id=2)
def test_one_label_at_k1(self): def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value( sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]]) [[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64) dense_labels = np.array([[3], [2]], dtype=np.int64)
@ -2992,13 +3028,18 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct. # Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 1, class_id=3) predictions, labels, k=1, expected=1.0 / 1, class_id=3)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 1, class_id=3)
# All classes: 2 labels, 2 predictions, 1 correct. # All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2) predictions, labels, k=1, expected=1.0 / 2)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2)
def test_one_label_at_k1_weighted(self): def test_one_label_at_k1_weighted(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value( sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]]) [[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64) dense_labels = np.array([[3], [2]], dtype=np.int64)
@ -3007,6 +3048,8 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct. # Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3014,6 +3057,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1, expected=1.0 / 1,
class_id=3, class_id=3,
weights=(1.0,)) weights=(1.0,))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1,
class_id=3,
weights=(1.0,))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3021,6 +3070,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1, expected=1.0 / 1,
class_id=3, class_id=3,
weights=(2.0,)) weights=(2.0,))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1,
class_id=3,
weights=(2.0,))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3028,6 +3083,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN, expected=NAN,
class_id=3, class_id=3,
weights=(0.0, 0.0)) weights=(0.0, 0.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=3,
weights=(0.0, 0.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3035,6 +3096,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN, expected=NAN,
class_id=3, class_id=3,
weights=(0.0, 1.0)) weights=(0.0, 1.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=3,
weights=(0.0, 1.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3042,6 +3109,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1, expected=1.0 / 1,
class_id=3, class_id=3,
weights=(1.0, 0.0)) weights=(1.0, 0.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1,
class_id=3,
weights=(1.0, 0.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3049,6 +3122,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1, expected=1.0 / 1,
class_id=3, class_id=3,
weights=(1.0, 1.0)) weights=(1.0, 1.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1,
class_id=3,
weights=(1.0, 1.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3056,6 +3135,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2, expected=2.0 / 2,
class_id=3, class_id=3,
weights=(2.0, 3.0)) weights=(2.0, 3.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=2.0 / 2,
class_id=3,
weights=(2.0, 3.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3063,6 +3148,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=3.0 / 3, expected=3.0 / 3,
class_id=3, class_id=3,
weights=(3.0, 2.0)) weights=(3.0, 2.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=3.0 / 3,
class_id=3,
weights=(3.0, 2.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3070,6 +3161,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.3 / 0.3, expected=0.3 / 0.3,
class_id=3, class_id=3,
weights=(0.3, 0.6)) weights=(0.3, 0.6))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=0.3 / 0.3,
class_id=3,
weights=(0.3, 0.6))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3077,32 +3174,70 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.6 / 0.6, expected=0.6 / 0.6,
class_id=3, class_id=3,
weights=(0.6, 0.3)) weights=(0.6, 0.3))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=0.6 / 0.6,
class_id=3,
weights=(0.6, 0.3))
# All classes: 2 labels, 2 predictions, 1 correct. # All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, weights=(0.0,)) predictions, labels, k=1, expected=NAN, weights=(0.0,))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, weights=(0.0,))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
def test_three_labels_at_k5_nan(self): def test_three_labels_at_k5_nan(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
top_k_predictions = [
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
]
sparse_labels = _binary_2d_label_to_sparse_value( sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@ -3112,10 +3247,16 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (0, 3, 4, 6, 9, 10): for class_id in (0, 3, 4, 6, 9, 10):
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id) predictions, labels, k=5, expected=NAN, class_id=class_id)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_three_labels_at_k5_no_predictions(self): def test_three_labels_at_k5_no_predictions(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
top_k_predictions = [
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
]
sparse_labels = _binary_2d_label_to_sparse_value( sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@ -3124,10 +3265,16 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 8: 1 label, no predictions. # Class 8: 1 label, no predictions.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=8) predictions, labels, k=5, expected=0.0 / 1, class_id=8)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0 / 1, class_id=8)
def test_three_labels_at_k5(self): def test_three_labels_at_k5(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
top_k_predictions = [
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
]
sparse_labels = _binary_2d_label_to_sparse_value( sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@ -3136,23 +3283,35 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 2 labels, both correct. # Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=2) predictions, labels, k=5, expected=2.0 / 2, class_id=2)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=2.0 / 2, class_id=2)
# Class 5: 1 label, incorrect. # Class 5: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 1, class_id=5) predictions, labels, k=5, expected=1.0 / 1, class_id=5)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 1, class_id=5)
# Class 7: 1 label, incorrect. # Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=7) predictions, labels, k=5, expected=0.0 / 1, class_id=7)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0 / 1, class_id=7)
# All classes: 6 labels, 3 correct. # All classes: 6 labels, 3 correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=3.0 / 6) predictions, labels, k=5, expected=3.0 / 6)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=3.0 / 6)
def test_three_labels_at_k5_some_out_of_range(self): def test_three_labels_at_k5_some_out_of_range(self):
"""Tests that labels outside the [0, n_classes) count in denominator.""" """Tests that labels outside the [0, n_classes) count in denominator."""
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
top_k_predictions = [
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
]
sp_labels = sparse_tensor.SparseTensorValue( sp_labels = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
[1, 3]], [1, 3]],
@ -3167,6 +3326,11 @@ class StreamingSparseRecallTest(test.TestCase):
k=5, k=5,
expected=2.0 / 2, expected=2.0 / 2,
class_id=2) class_id=2)
self._test_sparse_recall_at_top_k(
sp_labels,
top_k_predictions,
expected=2.0 / 2,
class_id=2)
# Class 5: 1 label, incorrect. # Class 5: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
@ -3175,6 +3339,11 @@ class StreamingSparseRecallTest(test.TestCase):
k=5, k=5,
expected=1.0 / 1, expected=1.0 / 1,
class_id=5) class_id=5)
self._test_sparse_recall_at_top_k(
sp_labels,
top_k_predictions,
expected=1.0 / 1,
class_id=5)
# Class 7: 1 label, incorrect. # Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
@ -3183,16 +3352,30 @@ class StreamingSparseRecallTest(test.TestCase):
k=5, k=5,
expected=0.0 / 1, expected=0.0 / 1,
class_id=7) class_id=7)
self._test_sparse_recall_at_top_k(
sp_labels,
top_k_predictions,
expected=0.0 / 1,
class_id=7)
# All classes: 8 labels, 3 correct. # All classes: 8 labels, 3 correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8) predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
self._test_sparse_recall_at_top_k(
sp_labels, top_k_predictions, expected=3.0 / 8)
def test_3d_nan(self): def test_3d_nan(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
sparse_labels = _binary_3d_label_to_sparse_value( sparse_labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
[[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]]) [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
@ -3207,12 +3390,21 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (0, 3, 4, 6, 9, 10): for class_id in (0, 3, 4, 6, 9, 10):
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id) predictions, labels, k=5, expected=NAN, class_id=class_id)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_3d_no_predictions(self): def test_3d_no_predictions(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
sparse_labels = _binary_3d_label_to_sparse_value( sparse_labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@ -3229,12 +3421,21 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (1, 8): for class_id in (1, 8):
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0, class_id=class_id) predictions, labels, k=5, expected=0.0, class_id=class_id)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0, class_id=class_id)
def test_3d(self): def test_3d(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value( labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@ -3244,24 +3445,39 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 4 labels, all correct. # Class 2: 4 labels, all correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=4.0 / 4, class_id=2) predictions, labels, k=5, expected=4.0 / 4, class_id=2)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=4.0 / 4, class_id=2)
# Class 5: 2 labels, both correct. # Class 5: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=5) predictions, labels, k=5, expected=2.0 / 2, class_id=5)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=2.0 / 2, class_id=5)
# Class 7: 2 labels, 1 incorrect. # Class 7: 2 labels, 1 incorrect.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 2, class_id=7) predictions, labels, k=5, expected=1.0 / 2, class_id=7)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2, class_id=7)
# All classes: 12 labels, 7 correct. # All classes: 12 labels, 7 correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=7.0 / 12) predictions, labels, k=5, expected=7.0 / 12)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=7.0 / 12)
def test_3d_ignore_all(self): def test_3d_ignore_all(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value( labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@ -3276,6 +3492,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN, expected=NAN,
class_id=class_id, class_id=class_id,
weights=[[0], [0]]) weights=[[0], [0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=class_id,
weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, predictions,
labels, labels,
@ -3283,16 +3505,33 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN, expected=NAN,
class_id=class_id, class_id=class_id,
weights=[[0, 0], [0, 0]]) weights=[[0, 0], [0, 0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=class_id,
weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0], [0]]) predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]]) predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, weights=[[0, 0], [0, 0]])
def test_3d_ignore_some(self): def test_3d_ignore_some(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value( labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@ -3307,6 +3546,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2.0, expected=2.0 / 2.0,
class_id=2, class_id=2,
weights=[[1], [0]]) weights=[[1], [0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=2.0 / 2.0,
class_id=2,
weights=[[1], [0]])
# Class 2: 2 labels, both correct. # Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
@ -3316,6 +3561,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2.0, expected=2.0 / 2.0,
class_id=2, class_id=2,
weights=[[0], [1]]) weights=[[0], [1]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=2.0 / 2.0,
class_id=2,
weights=[[0], [1]])
# Class 7: 1 label, correct. # Class 7: 1 label, correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
@ -3325,6 +3576,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1.0, expected=1.0 / 1.0,
class_id=7, class_id=7,
weights=[[0], [1]]) weights=[[0], [1]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1.0,
class_id=7,
weights=[[0], [1]])
# Class 7: 1 label, incorrect. # Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
@ -3334,6 +3591,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.0 / 1.0, expected=0.0 / 1.0,
class_id=7, class_id=7,
weights=[[1], [0]]) weights=[[1], [0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=0.0 / 1.0,
class_id=7,
weights=[[1], [0]])
# Class 7: 2 labels, 1 correct. # Class 7: 2 labels, 1 correct.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
@ -3343,6 +3606,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 2.0, expected=1.0 / 2.0,
class_id=7, class_id=7,
weights=[[1, 0], [1, 0]]) weights=[[1, 0], [1, 0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 2.0,
class_id=7,
weights=[[1, 0], [1, 0]])
# Class 7: No labels. # Class 7: No labels.
self._test_streaming_sparse_recall_at_k( self._test_streaming_sparse_recall_at_k(
@ -3352,6 +3621,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN, expected=NAN,
class_id=7, class_id=7,
weights=[[0, 1], [0, 1]]) weights=[[0, 1], [0, 1]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=7,
weights=[[0, 1], [0, 1]])
def test_sparse_tensor_value(self): def test_sparse_tensor_value(self):
predictions = [[0.1, 0.3, 0.2, 0.4], predictions = [[0.1, 0.3, 0.2, 0.4],

View File

@ -304,6 +304,7 @@ filegroup(
exclude = [ exclude = [
"**/METADATA", "**/METADATA",
"**/OWNERS", "**/OWNERS",
"tools/**",
], ],
), ),
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
@ -351,3 +352,27 @@ tf_kernel_library(
"//third_party/eigen3", "//third_party/eigen3",
], ],
) )
py_binary(
name = "checkpoint_convert",
srcs = ["python/tools/checkpoint_convert.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
)
py_test(
name = "checkpoint_convert_test",
size = "small",
srcs = ["python/tools/checkpoint_convert_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":checkpoint_convert",
"//tensorflow/python:client_testlib",
],
)

View File

@ -74,7 +74,41 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)): "root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2]) x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 2])
g, _ = core_rnn_cell_impl.BasicRNNCell(2)(x, m) cell = core_rnn_cell_impl.BasicRNNCell(2)
g, _ = cell(x, m)
self.assertEqual(
["root/basic_rnn_cell/%s:0"
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
"root/basic_rnn_cell/%s:0"
% core_rnn_cell_impl._BIAS_VARIABLE_NAME],
[v.name for v in cell.trainable_variables])
self.assertFalse(cell.non_trainable_variables)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g], {x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1]])})
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellNotTrainable(self):
with self.test_session() as sess:
def not_trainable_getter(getter, *args, **kwargs):
kwargs["trainable"] = False
return getter(*args, **kwargs)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5),
custom_getter=not_trainable_getter):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
cell = core_rnn_cell_impl.BasicRNNCell(2)
g, _ = cell(x, m)
self.assertFalse(cell.trainable_variables)
self.assertEqual(
["root/basic_rnn_cell/%s:0"
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
"root/basic_rnn_cell/%s:0"
% core_rnn_cell_impl._BIAS_VARIABLE_NAME],
[v.name for v in cell.non_trainable_variables])
sess.run([variables_lib.global_variables_initializer()]) sess.run([variables_lib.global_variables_initializer()])
res = sess.run( res = sess.run(
[g], {x.name: np.array([[1., 1.]]), [g], {x.name: np.array([[1., 1.]]),
@ -114,10 +148,23 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)): "root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2]) x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 8]) m = array_ops.zeros([1, 8])
g, out_m = core_rnn_cell_impl.MultiRNNCell( cell = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.BasicLSTMCell( [core_rnn_cell_impl.BasicLSTMCell(
2, state_is_tuple=False) for _ in range(2)], 2, state_is_tuple=False) for _ in range(2)],
state_is_tuple=False)(x, m) state_is_tuple=False)
g, out_m = cell(x, m)
expected_variable_names = [
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
% core_rnn_cell_impl._BIAS_VARIABLE_NAME,
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
% core_rnn_cell_impl._BIAS_VARIABLE_NAME]
self.assertEqual(
expected_variable_names, [v.name for v in cell.trainable_variables])
self.assertFalse(cell.non_trainable_variables)
sess.run([variables_lib.global_variables_initializer()]) sess.run([variables_lib.global_variables_initializer()])
res = sess.run( res = sess.run(
[g, out_m], [g, out_m],
@ -125,15 +172,7 @@ class RNNCellTest(test.TestCase):
m.name: 0.1 * np.ones([1, 8])}) m.name: 0.1 * np.ones([1, 8])})
self.assertEqual(len(res), 2) self.assertEqual(len(res), 2)
variables = variables_lib.global_variables() variables = variables_lib.global_variables()
self.assertEqual(4, len(variables)) self.assertEqual(expected_variable_names, [v.name for v in variables])
self.assertEquals(variables[0].op.name,
"root/multi_rnn_cell/cell_0/basic_lstm_cell/weights")
self.assertEquals(variables[1].op.name,
"root/multi_rnn_cell/cell_0/basic_lstm_cell/biases")
self.assertEquals(variables[2].op.name,
"root/multi_rnn_cell/cell_1/basic_lstm_cell/weights")
self.assertEquals(variables[3].op.name,
"root/multi_rnn_cell/cell_1/basic_lstm_cell/biases")
# The numbers in results were not calculated, this is just a smoke test. # The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
expected_mem = np.array([[ expected_mem = np.array([[

View File

@ -27,7 +27,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import contextlib
import hashlib import hashlib
import math import math
import numbers import numbers
@ -57,53 +56,6 @@ _BIAS_VARIABLE_NAME = "biases"
_WEIGHTS_VARIABLE_NAME = "weights" _WEIGHTS_VARIABLE_NAME = "weights"
@contextlib.contextmanager
def _checked_scope(cell, scope, reuse=None, **kwargs):
if reuse is not None:
kwargs["reuse"] = reuse
with vs.variable_scope(scope, **kwargs) as checking_scope:
scope_name = checking_scope.name
if hasattr(cell, "_scope"):
cell_scope = cell._scope # pylint: disable=protected-access
if cell_scope.name != checking_scope.name:
raise ValueError(
"Attempt to reuse RNNCell %s with a different variable scope than "
"its first use. First use of cell was with scope '%s', this "
"attempt is with scope '%s'. Please create a new instance of the "
"cell if you would like it to use a different set of weights. "
"If before you were using: MultiRNNCell([%s(...)] * num_layers), "
"change to: MultiRNNCell([%s(...) for _ in range(num_layers)]). "
"If before you were using the same cell instance as both the "
"forward and reverse cell of a bidirectional RNN, simply create "
"two instances (one for forward, one for reverse). "
"In May 2017, we will start transitioning this cell's behavior "
"to use existing stored weights, if any, when it is called "
"with scope=None (which can lead to silent model degradation, so "
"this error will remain until then.)"
% (cell, cell_scope.name, scope_name, type(cell).__name__,
type(cell).__name__))
else:
weights_found = False
try:
with vs.variable_scope(checking_scope, reuse=True):
vs.get_variable(_WEIGHTS_VARIABLE_NAME)
weights_found = True
except ValueError:
pass
if weights_found and reuse is None:
raise ValueError(
"Attempt to have a second RNNCell use the weights of a variable "
"scope that already has weights: '%s'; and the cell was not "
"constructed as %s(..., reuse=True). "
"To share the weights of an RNNCell, simply "
"reuse it in your second calculation, or create a new one with "
"the argument reuse=True." % (scope_name, type(cell).__name__))
# Everything is OK. Update the cell's scope and yield it.
cell._scope = checking_scope # pylint: disable=protected-access
yield checking_scope
class BasicRNNCell(RNNCell): class BasicRNNCell(RNNCell):
"""The most basic RNN cell.""" """The most basic RNN cell."""

View File

@ -39,9 +39,6 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
_checked_scope = core_rnn_cell_impl._checked_scope # pylint: disable=protected-access
def _get_concat_variable(name, shape, dtype, num_shards): def _get_concat_variable(name, shape, dtype, num_shards):
"""Get a sharded variable concatenated into one tensor.""" """Get a sharded variable concatenated into one tensor."""
sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)

View File

@ -0,0 +1,231 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints using RNNCells to new name convention.
Usage:
python checkpoint_convert [--write_v1_checkpoint] \
'/path/to/checkpoint' '/path/to/new_checkpoint'
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import re
import sys
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib
_RNN_NAME_REPLACEMENTS = collections.OrderedDict([
############################################################################
# contrib/rnn/python/ops/core_rnn_cell_impl.py
# BasicRNNCell
('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
# GRUCell
('gru_cell/weights', 'gru_cell/kernel'),
('gru_cell/biases', 'gru_cell/bias'),
('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
('gru_cell/gates/biases', 'gru_cell/gates/bias'),
('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
# BasicLSTMCell
('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
# LSTMCell
('lstm_cell/weights', 'lstm_cell/kernel'),
('lstm_cell/biases', 'lstm_cell/bias'),
('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
# OutputProjectionWrapper
('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
# InputProjectionWrapper
('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
############################################################################
# contrib/rnn/python/ops/lstm_ops.py
# LSTMBlockFusedCell ??
('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
############################################################################
# contrib/rnn/python/ops/rnn_cell.py
# LayerNormBasicLSTMCell
('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
# UGRNNCell, not found in g3, but still need it?
('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
('ugrnn_cell/biases', 'ugrnn_cell/bias'),
# NASCell
('nas_rnn/weights', 'nas_rnn/kernel'),
('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
# IntersectionRNNCell
('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
('intersection_rnn_cell/in_projection/weights',
'intersection_rnn_cell/in_projection/kernel'),
('intersection_rnn_cell/in_projection/biases',
'intersection_rnn_cell/in_projection/bias'),
# PhasedLSTMCell
('phased_lstm_cell/mask_gates/weights',
'phased_lstm_cell/mask_gates/kernel'),
('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
('phased_lstm_cell/output_gate/weights',
'phased_lstm_cell/output_gate/kernel'),
('phased_lstm_cell/output_gate/biases',
'phased_lstm_cell/output_gate/bias'),
# AttentionCellWrapper
('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
('attention_cell_wrapper/attn_output_projection/weights',
'attention_cell_wrapper/attn_output_projection/kernel'),
('attention_cell_wrapper/attn_output_projection/biases',
'attention_cell_wrapper/attn_output_projection/bias'),
('attention_cell_wrapper/attention/weights',
'attention_cell_wrapper/attention/kernel'),
('attention_cell_wrapper/attention/biases',
'attention_cell_wrapper/attention/bias'),
])
_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
('LSTMCell/W_', 'lstm_cell/weights/part_'),
('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
('GRUCell/W_', 'gru_cell/weights/part_'),
('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
])
def _rnn_name_replacement(var_name):
for pattern in _RNN_NAME_REPLACEMENTS:
if pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern])
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
break
return var_name
def _rnn_name_replacement_sharded(var_name):
for pattern in _RNN_SHARDED_NAME_REPLACEMENTS:
if pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(pattern,
_RNN_SHARDED_NAME_REPLACEMENTS[pattern])
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
return var_name
def _split_sharded_vars(name_shape_map):
"""Split shareded variables.
Args:
name_shape_map: A dict from variable name to variable shape.
Returns:
not_sharded: Names of the non-sharded variables.
sharded: Names of the sharded varibales.
"""
sharded = []
not_sharded = []
for name in name_shape_map:
if re.match(name, '_[0-9]+$'):
if re.sub('_[0-9]+$', '_1', name) in name_shape_map:
sharded.append(name)
else:
not_sharded.append(name)
else:
not_sharded.append(name)
return not_sharded, sharded
def convert_names(checkpoint_from_path,
checkpoint_to_path,
write_v1_checkpoint=False):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
write_v1_checkpoint: Whether the output checkpoint will be in V1 format.
Returns:
A dictionary that maps the new variable names to the Variable objects.
A dictionary that maps the old variable names to the new variable names.
"""
with ops.Graph().as_default():
logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
not_sharded, sharded = _split_sharded_vars(name_shape_map)
new_variable_map = {}
conversion_map = {}
for var_name in not_sharded:
new_var_name = _rnn_name_replacement(var_name)
tensor = reader.get_tensor(var_name)
var = variables.Variable(tensor, name=var_name)
new_variable_map[new_var_name] = var
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
for var_name in sharded:
new_var_name = _rnn_name_replacement_sharded(var_name)
var = variables.Variable(tensor, name=var_name)
new_variable_map[new_var_name] = var
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
write_version = (saver_pb2.SaverDef.V1
if write_v1_checkpoint else saver_pb2.SaverDef.V2)
saver = saver_lib.Saver(new_variable_map, write_version=write_version)
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
saver.save(sess, checkpoint_to_path)
logging.info('Summary:')
logging.info(' Converted %d variable name(s).' % len(new_variable_map))
return new_variable_map, conversion_map
def main(_):
convert_names(
FLAGS.checkpoint_from_path,
FLAGS.checkpoint_to_path,
write_v1_checkpoint=FLAGS.write_v1_checkpoint)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument('checkpoint_from_path', type=str,
help='Path to source checkpoint to be read in.')
parser.add_argument('checkpoint_to_path', type=str,
help='Path to checkpoint to be written out.')
parser.add_argument('--write_v1_checkpoint', action='store_true',
help='Write v1 checkpoint')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,108 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for checkpoint converter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import tempfile
from tensorflow.contrib.rnn.python.tools import checkpoint_convert
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
class CheckpointConvertTest(test.TestCase):
def setUp(self):
self._old_ckpt_path = tempfile.mktemp()
self._new_ckpt_path = tempfile.mktemp()
ops.reset_default_graph()
def tearDown(self):
for file_name in glob.glob(self._old_ckpt_path + "*"):
os.remove(file_name)
for file_name in glob.glob(self._new_ckpt_path + "*"):
os.remove(file_name)
def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self):
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name]
self.assertTrue(old_name)
self.assertTrue(new_name)
self.assertNotEqual(old_name, new_name)
for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS:
new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name]
self.assertTrue(old_name)
self.assertTrue(new_name)
self.assertNotEqual(old_name, new_name)
def testConversionFromV2WithConvertedVariableNamesSucceeds(self):
variables.Variable(10.0, name="a")
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
variables.Variable(20.0, name=old_name)
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path)
self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
self.assertItemsEqual(
["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()),
new_var_map.keys())
self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map)
def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self):
variables.Variable(10.0, name="a")
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path)
self.assertItemsEqual(["a"], new_var_map.keys())
self.assertFalse(conversion_map)
def testConversionToV1Succeeds(self):
variables.Variable(10.0, name="a")
variables.Variable(
20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1])
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True)
self.assertItemsEqual(
["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]],
new_var_map.keys())
self.assertEqual(
{list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]:
list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]},
conversion_map)
if __name__ == "__main__":
test.main()

View File

@ -261,7 +261,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import clip_ops from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -657,7 +657,7 @@ def train(train_op,
if local_init_op == _USE_DEFAULT: if local_init_op == _USE_DEFAULT:
local_init_op = control_flow_ops.group( local_init_op = control_flow_ops.group(
tf_variables.local_variables_initializer(), tf_variables.local_variables_initializer(),
data_flow_ops.tables_initializer()) lookup_ops.tables_initializer())
if sync_optimizer is not None and isinstance( if sync_optimizer is not None and isinstance(
sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):

View File

@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [
"framework/versions.proto", "framework/versions.proto",
"lib/core/error_codes.proto", "lib/core/error_codes.proto",
"protobuf/config.proto", "protobuf/config.proto",
"protobuf/cluster.proto",
"protobuf/debug.proto", "protobuf/debug.proto",
"protobuf/queue_runner.proto", "protobuf/queue_runner.proto",
"protobuf/rewriter_config.proto", "protobuf/rewriter_config.proto",
@ -506,6 +507,7 @@ tf_gen_op_libs(
"image_ops", "image_ops",
"io_ops", "io_ops",
"linalg_ops", "linalg_ops",
"lookup_ops",
"logging_ops", "logging_ops",
"math_ops", "math_ops",
"nn_ops", "nn_ops",
@ -582,6 +584,7 @@ cc_library(
":image_ops_op_lib", ":image_ops_op_lib",
":io_ops_op_lib", ":io_ops_op_lib",
":linalg_ops_op_lib", ":linalg_ops_op_lib",
":lookup_ops_op_lib",
":logging_ops_op_lib", ":logging_ops_op_lib",
":math_ops_op_lib", ":math_ops_op_lib",
":nn_ops_op_lib", ":nn_ops_op_lib",
@ -708,6 +711,7 @@ cc_library(
"//tensorflow/core/kernels:image", "//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io", "//tensorflow/core/kernels:io",
"//tensorflow/core/kernels:linalg", "//tensorflow/core/kernels:linalg",
"//tensorflow/core/kernels:lookup",
"//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:logging",
"//tensorflow/core/kernels:math", "//tensorflow/core/kernels:math",
"//tensorflow/core/kernels:multinomial_op", "//tensorflow/core/kernels:multinomial_op",

View File

@ -23,8 +23,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
Device::Device(Env* env, const DeviceAttributes& device_attributes, Device::Device(Env* env, const DeviceAttributes& device_attributes)
Allocator* device_allocator)
: DeviceBase(env), device_attributes_(device_attributes) { : DeviceBase(env), device_attributes_(device_attributes) {
CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_)) CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
<< "Invalid device name: " << name(); << "Invalid device name: " << name();

View File

@ -53,8 +53,7 @@ namespace tensorflow {
class Device : public DeviceBase { class Device : public DeviceBase {
public: public:
Device(Env* env, const DeviceAttributes& device_attributes, Device(Env* env, const DeviceAttributes& device_attributes);
Allocator* device_allocator);
~Device() override; ~Device() override;
// Full name of this device (see top comment). // Full name of this device (see top comment).

View File

@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
for (Device* d : devices) { for (Device* d : devices) {
devices_.push_back(d); devices_.push_back(d);
// Register under both the full name and the local name. // Register under the (1) full name, (2) canonical name, and (3) local name.
string full_name = d->name(); string full_name = d->name();
device_map_[CopyToBackingStore(full_name)] = d; device_map_[CopyToBackingStore(full_name)] = d;
DeviceNameUtils::ParsedName parsed_name = d->parsed_name();
if (parsed_name.has_job && parsed_name.has_replica &&
parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) {
string canonical_name = DeviceNameUtils::FullName(
parsed_name.job, parsed_name.replica, parsed_name.task,
parsed_name.type, parsed_name.id);
device_map_[CopyToBackingStore(canonical_name)] = d;
}
string lname = DeviceNameUtils::LocalName(d->name()); string lname = DeviceNameUtils::LocalName(d->name());
device_map_[CopyToBackingStore(lname)] = d; device_map_[CopyToBackingStore(lname)] = d;
device_type_counts_[d->device_type()]++; device_type_counts_[d->device_type()]++;
@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
} }
DeviceMgr::~DeviceMgr() { DeviceMgr::~DeviceMgr() {
for (auto p : devices_) delete p; // TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
for (Device* p : devices_) delete p;
} }
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
Status s; Status s;
auto iter = device_map_.find(name); auto iter = device_map_.find(name);
if (iter == device_map_.end()) { if (iter == device_map_.end()) {
std::vector<StringPiece> device_names;
for (auto&& itr : device_map_) {
device_names.push_back(itr.first);
}
LOG(WARNING) << "Unknown device: " << name
<< " all devices: " << str_util::Join(device_names, ", ");
return errors::InvalidArgument(name, " unknown device."); return errors::InvalidArgument(name, " unknown device.");
} }
*device = iter->second; *device = iter->second;

View File

@ -36,6 +36,7 @@ class DeviceMgr {
public: public:
// Takes ownership of each device in 'devices'. // Takes ownership of each device in 'devices'.
// TODO(zhifengc): Other initialization information. // TODO(zhifengc): Other initialization information.
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
explicit DeviceMgr(const std::vector<Device*>& devices); explicit DeviceMgr(const std::vector<Device*>& devices);
~DeviceMgr(); ~DeviceMgr();
@ -61,6 +62,7 @@ class DeviceMgr {
int NumDeviceType(const string& type) const; int NumDeviceType(const string& type) const;
private: private:
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
typedef gtl::InlinedVector<Device*, 8> DeviceVec; typedef gtl::InlinedVector<Device*, 8> DeviceVec;
DeviceVec devices_; DeviceVec devices_;

View File

@ -39,7 +39,10 @@ class DeviceSet {
// Set the device designated as the "client". This device // Set the device designated as the "client". This device
// must also be registered via AddDevice(). // must also be registered via AddDevice().
void set_client_device(Device* device) { client_device_ = device; } void set_client_device(Device* device) {
DCHECK(client_device_ == nullptr);
client_device_ = device;
}
// Returns a pointer to the device designated as the "client". // Returns a pointer to the device designated as the "client".
Device* client_device() const { return client_device_; } Device* client_device() const { return client_device_; }

View File

@ -27,8 +27,7 @@ namespace {
static Device* Dev(const char* type, const char* name) { static Device* Dev(const char* type, const char* name) {
class FakeDevice : public Device { class FakeDevice : public Device {
public: public:
explicit FakeDevice(const DeviceAttributes& attr) explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
: Device(nullptr, attr, nullptr) {}
Status Sync() override { return Status::OK(); } Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
}; };

View File

@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
int gpu_id, const string& physical_device_desc, int gpu_id, const string& physical_device_desc,
Allocator* gpu_allocator, Allocator* cpu_allocator, Allocator* gpu_allocator, Allocator* cpu_allocator,
bool sync_every_op, int32 max_streams) bool sync_every_op, int32 max_streams)
: LocalDevice(options, : LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU,
Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit, memory_limit, locality,
locality, physical_device_desc), physical_device_desc)),
gpu_allocator),
gpu_allocator_(gpu_allocator), gpu_allocator_(gpu_allocator),
cpu_allocator_(cpu_allocator), cpu_allocator_(cpu_allocator),
gpu_id_(gpu_id), gpu_id_(gpu_id),

View File

@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo {
}; };
LocalDevice::LocalDevice(const SessionOptions& options, LocalDevice::LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes, const DeviceAttributes& attributes)
Allocator* device_allocator) : Device(options.env, attributes), owned_tp_info_(nullptr) {
: Device(options.env, attributes, device_allocator),
owned_tp_info_(nullptr) {
// If we're running on the CPU, log warnings if we're not compiled using the // If we're running on the CPU, log warnings if we're not compiled using the
// best flags for performance. // best flags for performance.
port::WarnAboutUnusedCPUFeatures(); port::WarnAboutUnusedCPUFeatures();

View File

@ -33,8 +33,8 @@ struct SessionOptions;
// GPUDevice into more 'process-wide' abstractions. // GPUDevice into more 'process-wide' abstractions.
class LocalDevice : public Device { class LocalDevice : public Device {
public: public:
LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes, LocalDevice(const SessionOptions& options,
Allocator* device_allocator); const DeviceAttributes& attributes);
~LocalDevice() override; ~LocalDevice() override;
private: private:

View File

@ -0,0 +1,54 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/renamed_device.h"
namespace tensorflow {
// TODO(saeta): Convert to returning a std::unique_ptr?
/* static */
Device* RenamedDevice::NewRenamedDevice(const string& new_base,
Device* underlying,
bool owns_underlying) {
DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
DeviceNameUtils::ParsedName underlying_parsed_name =
underlying->parsed_name();
CHECK(underlying_parsed_name.has_type);
CHECK(underlying_parsed_name.has_id);
parsed_name.type = underlying_parsed_name.type;
parsed_name.id = underlying_parsed_name.id;
string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica,
parsed_name.task, parsed_name.type,
parsed_name.id);
DeviceAttributes attributes(underlying->attributes());
attributes.set_name(name);
return new RenamedDevice(underlying, attributes, owns_underlying);
}
RenamedDevice::RenamedDevice(Device* underlying,
const DeviceAttributes& attributes,
bool owns_underlying)
: Device(underlying->env(), attributes),
underlying_(underlying),
owns_underlying_(owns_underlying) {}
RenamedDevice::~RenamedDevice() {
if (owns_underlying_) {
delete underlying_;
}
}
} // namespace tensorflow

View File

@ -0,0 +1,119 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// Wraps a device with a new name, delegating work to the wrapped device.
//
// This class is used to wrap local devices when using clusterspec propagation
// where the name of a particular device may change in the context of a given
// session.
class RenamedDevice : public Device {
public:
static Device* NewRenamedDevice(const string& new_base, Device* underlying,
bool owns_underlying);
~RenamedDevice() override;
// Below are virtual methods defined on DeviceBase
bool RequiresRecordingAccessedTensors() const override {
return underlying_->RequiresRecordingAccessedTensors();
}
const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
return underlying_->tensorflow_cpu_worker_threads();
}
const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
return underlying_->tensorflow_gpu_device_info();
}
Allocator* GetAllocator(AllocatorAttributes attr) override {
return underlying_->GetAllocator(attr);
}
Allocator* GetStepAllocator(AllocatorAttributes attr,
ResourceMgr* step_resource_manager) override {
return underlying_->GetStepAllocator(attr, step_resource_manager);
}
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
return underlying_->eigen_cpu_device();
}
#ifdef TENSORFLOW_USE_SYCL
const Eigen::SyclDevice* eigen_sycl_device() const override {
return underlying_->eigen_sycl_device();
}
#endif
PerOpGpuDevice* MakeGpuDevice() override {
return underlying_->MakeGpuDevice();
}
void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
DeviceContext* dc, Allocator* allocator) override {
underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
}
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override {
return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
}
// Below are virtual methods defined on Device
void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
underlying_->Compute(op_kernel, context);
}
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override {
underlying_->ComputeAsync(op_kernel, context, std::move(done));
}
void ConsumeListOfAccessedTensors(
DeviceContext* context, const TensorReferenceVector& tensors) override {
underlying_->ConsumeListOfAccessedTensors(context, tensors);
}
Status Sync() override { return underlying_->Sync(); }
Status MaybeRewriteGraph(const FunctionDefLibrary& library,
std::unique_ptr<Graph>* graph) override {
return underlying_->MaybeRewriteGraph(library, graph);
}
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override {
return underlying_->FillContextMap(graph, device_context_map);
}
private:
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
bool owns_underlying);
Device* const underlying_;
const bool owns_underlying_;
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_

View File

@ -66,7 +66,7 @@ class DummyOp : public OpKernel {
class FakeDevice : public Device { class FakeDevice : public Device {
private: private:
explicit FakeDevice(const DeviceAttributes& device_attributes) explicit FakeDevice(const DeviceAttributes& device_attributes)
: Device(nullptr, device_attributes, nullptr) {} : Device(nullptr, device_attributes) {}
public: public:
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }

View File

@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
const string& name, Bytes memory_limit, const string& name, Bytes memory_limit,
const DeviceLocality& locality, const DeviceLocality& locality,
Allocator* allocator) Allocator* allocator)
: LocalDevice(options, : LocalDevice(options, Device::BuildDeviceAttributes(
Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit, name, DEVICE_CPU, memory_limit, locality)),
locality),
allocator),
allocator_(allocator) {} allocator_(allocator) {}
ThreadPoolDevice::~ThreadPoolDevice() {} ThreadPoolDevice::~ThreadPoolDevice() {}

View File

@ -77,7 +77,6 @@ cc_library(
], ],
deps = [ deps = [
":graph_mgr", ":graph_mgr",
":rendezvous_mgr_interface",
":worker_cache", ":worker_cache",
"//tensorflow/core:master_proto_cc", "//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -92,9 +91,9 @@ cc_library(
deps = [ deps = [
":graph_mgr", ":graph_mgr",
":worker_session", ":worker_session",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
], ],
) )
@ -237,6 +236,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc", "//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc", "//tensorflow/core:worker_proto_cc",
], ],
) )

View File

@ -35,9 +35,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env, BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
const string& worker_name) : worker_env_(worker_env) {}
: worker_env_(worker_env), worker_name_(worker_name) {}
BaseRendezvousMgr::~BaseRendezvousMgr() { BaseRendezvousMgr::~BaseRendezvousMgr() {
for (auto& p : table_) { for (auto& p : table_) {
@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() {
} }
} }
Rendezvous* BaseRendezvousMgr::Find(int64 step_id) { RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
return FindOrCreate(step_id); return FindOrCreate(step_id);
} }
@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
mutex_lock l(mu_); mutex_lock l(mu_);
Table::iterator iter = table_.find(step_id); Table::iterator iter = table_.find(step_id);
if (iter == table_.end()) { if (iter == table_.end()) {
auto rr = Create(step_id, worker_env_, worker_name_); auto rr = Create(step_id, worker_env_);
iter = table_.insert({step_id, rr}).first; iter = table_.insert({step_id, rr}).first;
} }
iter->second->Ref(); iter->second->Ref();
@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() {
} }
} }
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
const string& worker_name,
int64 step_id,
bool tolerate_dup_recv) bool tolerate_dup_recv)
: env_(env), : env_(env),
worker_name_(worker_name),
step_id_(step_id), step_id_(step_id),
local_(NewLocalRendezvous(tolerate_dup_recv)) {} local_(NewLocalRendezvous(tolerate_dup_recv)),
session_(nullptr) {}
BaseRemoteRendezvous::~BaseRemoteRendezvous() { BaseRemoteRendezvous::~BaseRemoteRendezvous() {
CHECK(active_.empty()); CHECK(active_.empty());
@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name,
return device_name.starts_with(worker_name); return device_name.starts_with(worker_name);
} }
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
CHECK_NE(session, nullptr) << "session must not be null!";
std::vector<DeferredCall> deferred_calls;
{
mutex_lock l(mu_);
if (session_ != nullptr) {
if (session_->worker_name == session->worker_name) {
LOG(INFO) << "Skipping rendezvous re-initialization.";
return Status::OK();
}
Status s = errors::Internal(
"Double init! Worker names would have changed from: ",
session_->worker_name, " -> ", session->worker_name);
LOG(WARNING) << s;
return s;
}
session_ = session;
std::swap(deferred_calls, deferred_calls_);
}
for (DeferredCall& call : deferred_calls) {
RecvLocalAsyncInternal(call.parsed, std::move(call.done));
}
return Status::OK();
}
WorkerSession* BaseRemoteRendezvous::session() {
mutex_lock l(mu_);
return session_;
}
bool BaseRemoteRendezvous::is_initialized() {
mutex_lock l(mu_);
return is_initialized_locked();
}
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args, const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) { const Tensor& val, const bool is_dead) {
@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!status_.ok()) return status_; if (!status_.ok()) return status_;
DCHECK(is_initialized_locked());
if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
return errors::InvalidArgument(
"Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
session_->worker_name);
} }
if (!IsLocalDevice(worker_name_, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ",
parsed.FullKey(), " @ ", worker_name_);
} }
// Buffers "val" and "device_context" in local_. // Buffers "val" and "device_context" in local_.
return local_->Send(parsed, args, val, is_dead); return local_->Send(parsed, args, val, is_dead);
@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
bool is_src) { bool is_src) {
// Cache session pointer to avoid repeatedly taking & releasing the lock
// (e.g. calling session())
WorkerSession* sess = nullptr;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!status_.ok()) return status_; if (!status_.ok()) return status_;
if (!is_initialized_locked()) {
return errors::Internal("ValidateDevices called before initialization.");
} }
if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) { sess = session_;
}
if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ", return errors::InvalidArgument("Invalid rendezvous key (src): ",
parsed.FullKey(), " @ ", worker_name_); parsed.FullKey(), " @ ", sess->worker_name);
} }
if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) { if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ", return errors::InvalidArgument("Invalid rendezvous key (dst): ",
parsed.FullKey(), " @ ", worker_name_); parsed.FullKey(), " @ ", sess->worker_name);
} }
return Status::OK(); return Status::OK();
} }
@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
const Rendezvous::Args& recv_args, const Rendezvous::Args& recv_args,
DoneCallback done) { DoneCallback done) {
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
CHECK(is_initialized()) << "RecvAsync called when uninitialized.";
Status s = ValidateDevices(parsed, false /*!is_src*/); Status s = ValidateDevices(parsed, false /*!is_src*/);
if (!s.ok()) { if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false); done(s, Args(), recv_args, Tensor(), false);
@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
DoneCallback done) { DoneCallback done) {
{
mutex_lock l(mu_);
if (!is_initialized_locked()) {
// RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
// remote worker) before the RunStep (or PartialRunStep) RPC from the
// master arrives. RecvLocalAsync thus buffers the arguments until after
// the RemoteRendezvous is Initialize()'d, when it completes the
// rendezvous logic. At some point after Initialize() is called, a Tensor
// is produced locally that will then be sent in response to the incoming
// RPC.
DeferredCall call(parsed, std::move(done));
deferred_calls_.push_back(call);
return;
}
}
RecvLocalAsyncInternal(parsed, std::move(done));
}
void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
DoneCallback done) {
Status s = ValidateDevices(parsed, true /* is_src */); Status s = ValidateDevices(parsed, true /* is_src */);
if (!s.ok()) { if (!s.ok()) {
done(s, Args(), Args(), Tensor(), false); done(s, Args(), Args(), Tensor(), false);
@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
active_.erase(call); active_.erase(call);
} }
BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
DoneCallback done)
: parsed(parsed), done(std::move(done)) {}
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -59,15 +59,17 @@ class BaseRecvTensorCall;
// RendezvousMgr must have keys generated by Rendezvous::CreateKey(). // RendezvousMgr must have keys generated by Rendezvous::CreateKey().
class BaseRendezvousMgr : public RendezvousMgrInterface { class BaseRendezvousMgr : public RendezvousMgrInterface {
public: public:
explicit BaseRendezvousMgr(const WorkerEnv* worker_env, explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
const string& worker_name);
~BaseRendezvousMgr() override; ~BaseRendezvousMgr() override;
// Returns Rendezvous supporting send and recv among workers in the // Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the // "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance. // returned Rendezvous instance.
Rendezvous* Find(int64 step_id) override; //
// Note: the caller must guarantee to eventually call Initialize on the
// returned RemoteRendezvous
RemoteRendezvous* Find(int64 step_id) override;
// Finds the local rendezvous instance for the "step_id". Runs // Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs. // "done" when the tensor for "key" is produced or an error occurs.
@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
protected: protected:
virtual BaseRemoteRendezvous* Create(int64 step_id, virtual BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env, const WorkerEnv* worker_env) = 0;
const string& worker_name) = 0;
private: private:
// Maps step_id to rendezvous. // Maps step_id to rendezvous.
@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Not owned. // Not owned.
const WorkerEnv* const worker_env_; const WorkerEnv* const worker_env_;
const string worker_name_;
mutex mu_; mutex mu_;
Table table_ GUARDED_BY(mu_); Table table_ GUARDED_BY(mu_);
@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Buffering of Tensor values is delegated to a "local" Rendezvous // Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous(). This class just adds // obtained from NewLocalRendezvous(). This class just adds
// functionality to coordinate with remote workers. // functionality to coordinate with remote workers.
class BaseRemoteRendezvous : public Rendezvous { class BaseRemoteRendezvous : public RemoteRendezvous {
public: public:
BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name, BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
int64 step_id, bool tolerate_dup_recv); bool tolerate_dup_recv);
// Upgrades the BaseRemoteRendezvous to full initialization.
Status Initialize(WorkerSession* session) override;
// Forwards to local_, where the Tensor "val" will be buffered and // Forwards to local_, where the Tensor "val" will be buffered and
// any waiting callback stored. // any waiting callback stored.
@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous {
// Removes "call" from active_ if "call" is in active_. // Removes "call" from active_ if "call" is in active_.
void DeregisterCall(BaseRecvTensorCall* call); void DeregisterCall(BaseRecvTensorCall* call);
WorkerSession* session();
bool is_initialized();
~BaseRemoteRendezvous() override; ~BaseRemoteRendezvous() override;
const WorkerEnv* const env_; // Not owned. const WorkerEnv* const env_; // Not owned.
const string worker_name_;
const int64 step_id_; const int64 step_id_;
private: private:
@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous {
// Status given by StartAbort() if any. // Status given by StartAbort() if any.
Status status_ GUARDED_BY(mu_); Status status_ GUARDED_BY(mu_);
WorkerSession* session_ GUARDED_BY(mu_); // Not owned.
// Data structures to handle calls when partially initialized.
struct DeferredCall {
const ParsedKey parsed;
DoneCallback done;
DeferredCall(const ParsedKey& parsed, DoneCallback done);
};
std::vector<DeferredCall> deferred_calls_ GUARDED_BY(mu_);
// Active outstanding RecvTensor calls. // Active outstanding RecvTensor calls.
gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_); gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return session_ != nullptr;
}
// If "is_src" is true, checks that the rendezvous key "parsed"'s // If "is_src" is true, checks that the rendezvous key "parsed"'s
// source is in this process. If "is_src" is false, checks that the // source is in this process. If "is_src" is false, checks that the
// rendezvous key "parsed"'s destination is in this process. // rendezvous key "parsed"'s destination is in this process.
@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous {
const Rendezvous::Args& out_args, const Tensor& in, const Rendezvous::Args& out_args, const Tensor& in,
Tensor* out, StatusCallback done); Tensor* out, StatusCallback done);
// Must be called only if fully initialized.
void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
}; };

View File

@ -46,10 +46,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
GraphMgr::GraphMgr(const WorkerEnv* worker_env, GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
RendezvousMgrInterface* rendezvous_mgr) : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
: worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) {
CHECK(rendezvous_mgr) << "Rendezvous mgr was null";
// The default value of sync_on_finish will be flipped soon and this // The default value of sync_on_finish will be flipped soon and this
// environment variable will be removed as well. // environment variable will be removed as well.
Status status = Status status =
@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
}; };
popts.get_incarnation = [this](const string& name) -> int64 { popts.get_incarnation = [this](const string& name) -> int64 {
Device* device = nullptr; Device* device = nullptr;
Status s = worker_env_->device_mgr->LookupDevice(name, &device); Status s = device_mgr_->LookupDevice(name, &device);
if (s.ok()) { if (s.ok()) {
return device->attributes().incarnation(); return device->attributes().incarnation();
} else { } else {
@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
ExecutionUnit* unit = &(item->units.back()); ExecutionUnit* unit = &(item->units.back());
// Find the device. // Find the device.
Status s = Status s = device_mgr_->LookupDevice(device_name, &unit->device);
worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
if (!s.ok()) { if (!s.ok()) {
// Remove the empty unit from the item as the item destructor wants all // Remove the empty unit from the item as the item destructor wants all
// units to have valid devices. // units to have valid devices.
@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
// Function library runtime. // Function library runtime.
unit->lib = NewFunctionLibraryRuntime( unit->lib = NewFunctionLibraryRuntime(
worker_env_->device_mgr, worker_env_->env, unit->device, device_mgr_, worker_env_->env, unit->device,
subgraph->versions().producer(), item->lib_def, subgraph->versions().producer(), item->lib_def,
graph_options.optimizer_options()); graph_options.optimizer_options());
@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
} }
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in); Status s = SendInputsToRendezvous(rendezvous, in);
rendezvous->Unref(); rendezvous->Unref();
return s; return s;
} }
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = RecvOutputsFromRendezvous(rendezvous, out); Status s = RecvOutputsFromRendezvous(rendezvous, out);
rendezvous->Unref(); rendezvous->Unref();
return s; return s;
@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
StatusCallback done) { StatusCallback done) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
RecvOutputsFromRendezvousAsync(rendezvous, out, RecvOutputsFromRendezvousAsync(rendezvous, out,
[done, rendezvous](const Status s) { [done, rendezvous](const Status s) {
rendezvous->Unref(); rendezvous->Unref();
@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
} }
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts, WorkerSession* session,
const ExecutorOpts& /*opts*/,
StepStatsCollector* collector, StepStatsCollector* collector,
CostGraphDef* cost_graph, CostGraphDef* cost_graph,
CancellationManager* cancellation_manager, CancellationManager* cancellation_manager,
@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
return; return;
} }
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = rendezvous->Initialize(session);
// Sends values specified by the caller. // Sends values specified by the caller.
Status s = SendInputsToRendezvous(rendezvous, in); if (s.ok()) {
s = SendInputsToRendezvous(rendezvous, in);
}
if (!s.ok()) { if (!s.ok()) {
done(s); done(s);
item->Unref(); item->Unref();
@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
StatusCallback done) { StatusCallback done) {
const int num_units = item->units.size(); const int num_units = item->units.size();
CHECK_GE(num_units, 1); CHECK_GE(num_units, 1);
ScopedStepContainer* step_container = ScopedStepContainer* step_container = new ScopedStepContainer(
new ScopedStepContainer(step_id, [this](const string& name) { step_id,
worker_env_->device_mgr->ClearContainers({name}); [this](const string& name) { device_mgr_->ClearContainers({name}); });
});
// NOTE: Transfer one ref of rendezvous and item. // NOTE: Transfer one ref of rendezvous and item.
ExecutorBarrier* barrier = ExecutorBarrier* barrier =
new ExecutorBarrier(num_units, rendezvous, new ExecutorBarrier(num_units, rendezvous,

View File

@ -37,6 +37,8 @@ namespace tensorflow {
class ExecutorOpts; class ExecutorOpts;
class StepStatsCollector; class StepStatsCollector;
class RendezvousMgrInterface; class RendezvousMgrInterface;
class DeviceMgr;
struct WorkerSession;
// GraphMgr keeps track of a set of graphs that are registered with a // GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle // TensorFlow worker. Each registered graph is identified by a handle
@ -62,8 +64,7 @@ class RendezvousMgrInterface;
// EXPECT_EQ(out["c"], Tensor({4, 6})); // EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr { class GraphMgr {
public: public:
explicit GraphMgr(const WorkerEnv* worker_env, explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
RendezvousMgrInterface* rendezvous_mgr);
~GraphMgr(); ~GraphMgr();
// Registers a graph. Fills in "handle" // Registers a graph. Fills in "handle"
@ -78,8 +79,8 @@ class GraphMgr {
typedef std::map<string, Tensor> NamedTensors; typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback; typedef std::function<void(const Status&)> StatusCallback;
void ExecuteAsync(const string& handle, const int64 step_id, void ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts, StepStatsCollector* collector, WorkerSession* session, const ExecutorOpts& opts,
CostGraphDef* cost_graph, StepStatsCollector* collector, CostGraphDef* cost_graph,
CancellationManager* cancellation_manager, CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done); const NamedTensors& in, StatusCallback done);
@ -131,7 +132,7 @@ class GraphMgr {
}; };
const WorkerEnv* worker_env_; // Not owned. const WorkerEnv* worker_env_; // Not owned.
RendezvousMgrInterface* rendezvous_mgr_; // Not owned. DeviceMgr* device_mgr_;
CostModelManager cost_model_manager_; CostModelManager cost_model_manager_;

View File

@ -34,6 +34,7 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/remote_device.h" #include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_cache.h"
@ -48,12 +49,17 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/master.pb.h" #include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
namespace tensorflow { namespace tensorflow {
namespace {
const char* const kGrpcProtocol = "grpc://";
} // namespace
Master::Master(MasterEnv* env, double session_gc_seconds) Master::Master(MasterEnv* env, double session_gc_seconds)
: env_(env), : env_(env),
last_1000_steps_(1000), last_1000_steps_(1000),
@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req,
CreateSessionResponse* resp, MyClosure done) { CreateSessionResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() { SchedClosure([this, req, resp, done]() {
Status status; Status status;
WorkerCacheFactoryOptions worker_cache_factory_options;
string grpc_protocol("grpc");
worker_cache_factory_options.protocol = &grpc_protocol;
auto call_done = gtl::MakeCleanup([&status, &done] { done(status); }); auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
status = ValidateExternalGraphDefSyntax(req->graph_def()); status = ValidateExternalGraphDefSyntax(req->graph_def());
if (!status.ok()) return; if (!status.ok()) return;
// Ping all the workers and build the list of devices that the
// session will use. // The following 4 variables are set differently, depending on whether this
// session uses a client-provided clusterspec or not.
WorkerCacheInterface* worker_cache = nullptr;
// Note: worker_cache_ptr will be null except if this session is using a
// client-supplied ClusterDef (ClusterSpec propagation).
std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
std::unique_ptr<DeviceSet> device_set;
// TODO(saeta): Convert to std::make_unique when available. // TODO(saeta): Convert to std::make_unique when available.
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices( std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
new std::vector<std::unique_ptr<Device>>()); new std::vector<std::unique_ptr<Device>>());
status = DeviceFinder::GetRemoteDevices(req->config().device_filters(),
env_, env_->worker_cache, if (req->config().has_cluster_def()) {
remote_devices.get()); worker_cache_factory_options.cluster_def = &req->config().cluster_def();
// Set the server_def's job_name and task_index fields.
string normalized_string;
string grpc_protocol(kGrpcProtocol);
if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
0) {
normalized_string =
req->target().substr(grpc_protocol.length(), string::npos);
} else {
normalized_string = req->target();
}
for (auto&& job : req->config().cluster_def().job()) {
for (auto&& task : job.tasks()) {
if (task.second == normalized_string) {
if (worker_cache_factory_options.job_name != nullptr) {
status = errors::InvalidArgument(
"Found multiple matching tasks that correspond to "
"to the master. Master target: '",
req->target(), "'. ClusterDef: ",
req->config().cluster_def().ShortDebugString());
LOG(ERROR) << status;
return;
}
if (env_->local_devices[0]->parsed_name().job == job.name() &&
env_->local_devices[0]->parsed_name().task == task.first) {
// TODO(b/37868888): Remove this limitation when resolved
status = errors::InvalidArgument(
"The ClusterSpec names the job and task index to be the same "
"names that were provided when the server booted. This is "
"currently not allowed. Job: ",
job.name(), ", task index: ", task.first);
return;
}
worker_cache_factory_options.job_name = &job.name();
worker_cache_factory_options.task_index = task.first;
}
}
}
// Create the worker cache from the computed server_def.
status = env_->worker_cache_factory(worker_cache_factory_options,
&worker_cache);
if (!status.ok()) return; if (!status.ok()) return;
worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
// Ping all the workers and build the list of devices that the
// session will use.
status =
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
worker_cache, remote_devices.get());
if (!status.ok()) return;
device_set.reset(new DeviceSet);
for (auto&& d : *remote_devices) {
device_set->AddDevice(d.get());
DeviceNameUtils::ParsedName name = d->parsed_name();
if (name.job == *worker_cache_factory_options.job_name &&
name.task == worker_cache_factory_options.task_index &&
name.type == "CPU") {
device_set->set_client_device(d.get());
}
}
} else {
worker_cache = env_->worker_cache;
// Ping all the workers and build the list of devices that the
// session will use.
status =
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
worker_cache, remote_devices.get());
if (!status.ok()) return;
device_set.reset(new DeviceSet);
for (auto&& d : *remote_devices) {
device_set->AddDevice(d.get());
}
int num_local_devices = 0;
for (Device* d : env_->local_devices) {
device_set->AddDevice(d);
if (num_local_devices == 0) {
// Uses the first local device as the client device.
device_set->set_client_device(d);
}
num_local_devices++;
}
}
CHECK(device_set->client_device());
SessionOptions options; SessionOptions options;
options.config = req->config(); options.config = req->config();
MasterSession* session =
env_->master_session_factory(options, env_, std::move(remote_devices)); MasterSession* session = env_->master_session_factory(
options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
std::move(device_set));
GraphDef* gdef = GraphDef* gdef =
const_cast<CreateSessionRequest*>(req)->mutable_graph_def(); const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
status = session->Create(gdef);
status = session->Create(gdef, worker_cache_factory_options);
if (!status.ok()) { if (!status.ok()) {
session->Close().IgnoreError(); session->Close().IgnoreError();
session->Unref(); session->Unref();

View File

@ -19,17 +19,41 @@ limitations under the License.
#include <functional> #include <functional>
#include <vector> #include <vector>
#include "tensorflow/core/distributed_runtime/master_session.h" #include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
namespace tensorflow { namespace tensorflow {
class Device; class Device;
class DeviceSet;
class Env; class Env;
class MasterSession; class MasterSession;
class OpRegistryInterface; class OpRegistryInterface;
class WorkerCacheInterface; class WorkerCacheInterface;
// Options passed to the worker_cache_factory function.
struct WorkerCacheFactoryOptions {
const ClusterDef* cluster_def = nullptr;
const string* job_name = nullptr;
int task_index;
const string* protocol = nullptr;
WorkerCacheFactoryOptions() {}
// Construct from a ServerDef proto.
//
// Note: server_def must outlive WorkerCacheFactoryOptions!
WorkerCacheFactoryOptions(const ServerDef& server_def) {
if (server_def.has_cluster() && !server_def.job_name().empty()) {
cluster_def = &server_def.cluster();
job_name = &server_def.job_name();
task_index = server_def.task_index();
protocol = &server_def.protocol();
}
}
};
// The master environment class, which holds a bag of pointers to // The master environment class, which holds a bag of pointers to
// per-master state. // per-master state.
// //
@ -57,8 +81,14 @@ struct MasterEnv {
// `MasterEnv*` is retained by the caller. // `MasterEnv*` is retained by the caller.
std::function<MasterSession*( std::function<MasterSession*(
SessionOptions, MasterEnv*, SessionOptions, MasterEnv*,
std::unique_ptr<std::vector<std::unique_ptr<Device>>>)> std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
std::unique_ptr<WorkerCacheInterface>,
std::unique_ptr<DeviceSet> device_set)>
master_session_factory; master_session_factory;
std::function<Status(const WorkerCacheFactoryOptions&,
WorkerCacheInterface**)>
worker_cache_factory;
}; };
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -36,11 +36,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -162,7 +164,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
// Partitions the graph into subgraphs and registers them on // Partitions the graph into subgraphs and registers them on
// workers. // workers.
Status RegisterPartitions(const PartitionOptions& popts, Status RegisterPartitions(const PartitionOptions& popts,
const FunctionDefLibrary& func_def_lib); const FunctionLibraryDefinition& flib_def);
// Runs one step of all partitions. // Runs one step of all partitions.
Status RunPartitions(const MasterEnv* env, int64 step_id, Status RunPartitions(const MasterEnv* env, int64 step_id,
@ -273,7 +275,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
}; };
Status MasterSession::ReffedClientGraph::RegisterPartitions( Status MasterSession::ReffedClientGraph::RegisterPartitions(
const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib) { const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) {
{ // Ensure register once. { // Ensure register once.
mu_.lock(); mu_.lock();
if (!init_started_) { if (!init_started_) {
@ -292,7 +294,8 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions(
graph_defs_for_publishing.push_back(&name_def.second); graph_defs_for_publishing.push_back(&name_def.second);
} }
stats_publisher_->PublishGraphProto(graph_defs_for_publishing); stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
s = DoRegisterPartitions(popts, func_def_lib, std::move(graph_defs)); s = DoRegisterPartitions(popts, flib_def.ToProto(),
std::move(graph_defs));
} }
mu_.lock(); mu_.lock();
init_result_ = s; init_result_ = s;
@ -527,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
c->req->set_is_partial(is_partial_); c->req->set_is_partial(is_partial_);
c->req->set_is_last_partial_run(is_last_partial_run); c->req->set_is_last_partial_run(is_last_partial_run);
} }
c->req->set_session_handle(session_handle_);
c->req->set_graph_handle(part.graph_handle); c->req->set_graph_handle(part.graph_handle);
c->req->set_step_id(step_id); c->req->set_step_id(step_id);
*c->req->mutable_exec_opts() = exec_opts; *c->req->mutable_exec_opts() = exec_opts;
@ -870,6 +874,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
// The graph handle may be empty if we failed during partition registration. // The graph handle may be empty if we failed during partition registration.
if (!part.graph_handle.empty()) { if (!part.graph_handle.empty()) {
Call* c = new Call; Call* c = new Call;
c->req.set_session_handle(session_handle_);
c->req.set_graph_handle(part.graph_handle); c->req.set_graph_handle(part.graph_handle);
// NOTE(mrry): We must capture `worker_cache_` since `this` // NOTE(mrry): We must capture `worker_cache_` since `this`
// could be deleted before the callback is called. // could be deleted before the callback is called.
@ -972,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
MasterSession::MasterSession( MasterSession::MasterSession(
const SessionOptions& opt, const MasterEnv* env, const SessionOptions& opt, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory) StatsPublisherFactory stats_publisher_factory)
: session_opts_(opt), : session_opts_(opt),
env_(env), env_(env),
handle_(strings::FpToString(random::New64())), handle_(strings::FpToString(random::New64())),
remote_devs_(std::move(remote_devs)), remote_devs_(std::move(remote_devs)),
worker_cache_(std::move(worker_cache)),
devices_(std::move(device_set)),
stats_publisher_factory_(std::move(stats_publisher_factory)), stats_publisher_factory_(std::move(stats_publisher_factory)),
graph_version_(0), graph_version_(0),
run_graphs_(5), run_graphs_(5),
partial_run_graphs_(5) { partial_run_graphs_(5) {
UpdateLastAccessTime(); UpdateLastAccessTime();
CHECK(devices_) << "device_set was null!";
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size() VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
<< " #remote " << remote_devs_->size(); << " #remote " << remote_devs_->size();
for (auto&& d : *remote_devs_) {
devices_.AddDevice(d.get());
}
int num_local_devices = 0;
for (Device* d : env->local_devices) {
devices_.AddDevice(d);
if (num_local_devices == 0) {
// Uses the first local device as the client device.
devices_.set_client_device(d);
}
num_local_devices++;
}
LOG(INFO) << "Start master session " << handle_ LOG(INFO) << "Start master session " << handle_
<< " with config: " << std::endl << " with config: " << std::endl
<< session_opts_.config.DebugString(); << session_opts_.config.DebugString();
@ -1011,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() {
last_access_time_usec_.store(Env::Default()->NowMicros()); last_access_time_usec_.store(Env::Default()->NowMicros());
} }
Status MasterSession::Create(GraphDef* graph_def) { Status MasterSession::Create(GraphDef* graph_def,
const WorkerCacheFactoryOptions& options) {
if (session_opts_.config.graph_options().place_pruned_graph()) { if (session_opts_.config.graph_options().place_pruned_graph()) {
// TODO(b/29900832): Fix this or remove the option. // TODO(b/29900832): Fix this or remove the option.
LOG(WARNING) << "Distributed session does not support the " LOG(WARNING) << "Distributed session does not support the "
@ -1019,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) {
session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false); session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
} }
SimpleGraphExecutionStateOptions options; SimpleGraphExecutionStateOptions execution_options;
options.device_set = &devices_; execution_options.device_set = devices_.get();
options.session_options = &session_opts_; execution_options.session_options = &session_opts_;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
graph_def, options, &execution_state_)); graph_def, execution_options, &execution_state_));
}
if (options.cluster_def != nullptr) {
return CreateWorkerSessions(options);
} }
return Status::OK(); return Status::OK();
} }
Status MasterSession::CreateWorkerSessions(
const WorkerCacheFactoryOptions& options) {
CHECK(worker_cache_) << "CreateWorkerSessions should be called only with "
<< "dynamic cluster membership.";
std::vector<string> worker_names;
worker_cache_->ListWorkers(&worker_names);
struct WorkerGroup {
// The worker name. (Not owned.)
const string* name;
// The worker referenced by name. (Not owned.)
WorkerInterface* worker = nullptr;
// Request and responses used for a given worker.
CreateWorkerSessionRequest request;
CreateWorkerSessionResponse response;
Status status = Status::OK();
};
BlockingCounter done(worker_names.size());
std::vector<WorkerGroup> workers(worker_names.size());
// Release the workers.
auto cleanup = gtl::MakeCleanup([this, &workers] {
for (auto&& worker_group : workers) {
if (worker_group.worker != nullptr) {
worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker);
}
}
});
Status status = Status::OK();
// Create all the workers & kick off the computations.
for (size_t i = 0; i < worker_names.size(); ++i) {
workers[i].name = &worker_names[i];
workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
workers[i].request.set_session_handle(handle_);
*workers[i].request.mutable_server_def()->mutable_cluster() =
*options.cluster_def;
workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
DeviceNameUtils::ParsedName name;
if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
status = errors::Internal("Could not parse name ", worker_names[i]);
LOG(WARNING) << status;
return status;
}
if (!name.has_job || !name.has_task) {
status = errors::Internal("Incomplete worker name ", worker_names[i]);
LOG(WARNING) << status;
return status;
}
workers[i].request.mutable_server_def()->set_job_name(name.job);
workers[i].request.mutable_server_def()->set_task_index(name.task);
}
for (size_t i = 0; i < worker_names.size(); ++i) {
auto cb = [i, &workers, &done](const Status& s) {
workers[i].status = s;
done.DecrementCount();
};
workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
&workers[i].response, cb);
}
done.Wait();
for (size_t i = 0; i < workers.size(); ++i) {
status.Update(workers[i].status);
}
return status;
}
Status MasterSession::Extend(const ExtendSessionRequest* req, Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) { ExtendSessionResponse* resp) {
UpdateLastAccessTime(); UpdateLastAccessTime();
@ -1059,6 +1135,13 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
return Status::OK(); return Status::OK();
} }
WorkerCacheInterface* MasterSession::get_worker_cache() const {
if (worker_cache_) {
return worker_cache_.get();
}
return env_->worker_cache;
}
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** rcg, bool is_partial) { ReffedClientGraph** rcg, bool is_partial) {
const uint64 hash = HashBuildGraphOptions(opts); const uint64 hash = HashBuildGraphOptions(opts);
@ -1082,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
<< "\n"; << "\n";
std::unique_ptr<SimpleClientGraph> client_graph; std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph( auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_, handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial, stats_publisher_factory_, execution_state_.get(), is_partial,
env_->worker_cache); worker_cache);
iter = m->insert({hash, entry}).first; iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph"; VLOG(1) << "Preparing to execute new graph";
} }
@ -1161,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
return errors::FailedPrecondition("Session is closed."); return errors::FailedPrecondition("Session is closed.");
} }
++num_running_; ++num_running_;
// Note: all code paths must eventually call MarkRunCompletion()
// in order to appropriate decrement the num_running_ counter.
} }
Status status; Status status;
if (!req.partial_run_handle().empty()) { if (!req.partial_run_handle().empty()) {
@ -1168,15 +1253,17 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
} else { } else {
status = DoRunWithLocalExecution(opts, req, resp); status = DoRunWithLocalExecution(opts, req, resp);
} }
{ return status;
}
// Decrements num_running_ and broadcasts if num_running_ is zero.
void MasterSession::MarkRunCompletion() {
mutex_lock l(mu_); mutex_lock l(mu_);
--num_running_; --num_running_;
if (num_running_ == 0) { if (num_running_ == 0) {
num_running_is_zero_.notify_all(); num_running_is_zero_.notify_all();
} }
} }
return status;
}
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
// Registers subgraphs if haven't done so. // Registers subgraphs if haven't done so.
@ -1187,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
return strings::StrCat(prefix, "_S", next_node_id_++); return strings::StrCat(prefix, "_S", next_node_id_++);
}; };
popts.get_incarnation = [this](const string& name) -> int64 { popts.get_incarnation = [this](const string& name) -> int64 {
Device* d = devices_.FindDeviceByName(name); Device* d = devices_->FindDeviceByName(name);
if (d == nullptr) { if (d == nullptr) {
return PartitionOptions::kIllegalIncarnation; return PartitionOptions::kIllegalIncarnation;
} else { } else {
@ -1214,7 +1301,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
} }
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
rcg->RegisterPartitions(popts, rcg->client_graph()->flib_def->ToProto())); rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def));
return Status::OK(); return Status::OK();
} }
@ -1222,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
Status MasterSession::DoPartialRun(CallOptions* opts, Status MasterSession::DoPartialRun(CallOptions* opts,
const RunStepRequestWrapper& req, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) { MutableRunStepResponseWrapper* resp) {
auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
const string& prun_handle = req.partial_run_handle(); const string& prun_handle = req.partial_run_handle();
RunState* run_state = nullptr; RunState* run_state = nullptr;
{ {
@ -1320,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
rcg->Ref(); rcg->Ref();
rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(), rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
req.options(), resp->mutable_metadata()); req.options(), resp->mutable_metadata());
cleanup.release(); // MarkRunCompletion called in done closure.
rcg->CleanupPartitionsAsync( rcg->CleanupPartitionsAsync(
run_state->step_id, [this, rcg, prun_handle](const Status& s) { run_state->step_id, [this, rcg, prun_handle](const Status& s) {
if (!s.ok()) { if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s; LOG(ERROR) << "Cleanup partition error: " << s;
} }
rcg->Unref(); rcg->Unref();
MarkRunCompletion();
}); });
mutex_lock l(mu_); mutex_lock l(mu_);
partial_runs_.erase(prun_handle); partial_runs_.erase(prun_handle);
@ -1367,10 +1457,10 @@ Status MasterSession::CreateDebuggerState(
Status MasterSession::DoRunWithLocalExecution( Status MasterSession::DoRunWithLocalExecution(
CallOptions* opts, const RunStepRequestWrapper& req, CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) { MutableRunStepResponseWrapper* resp) {
VLOG(2) << "DoRunWithLocalExecution " VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
<< "req: " << req.DebugString();
PerStepState pss; PerStepState pss;
pss.start_micros = Env::Default()->NowMicros(); pss.start_micros = Env::Default()->NowMicros();
auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
// Prepare. // Prepare.
BuildGraphOptions bgopts; BuildGraphOptions bgopts;
@ -1437,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution(
} }
} }
rcg->Ref(); rcg->Ref();
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) { cleanup.release(); // MarkRunCompletion called in done closure.
rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
if (!s.ok()) { if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s; LOG(ERROR) << "Cleanup partition error: " << s;
} }
rcg->Unref(); rcg->Unref();
MarkRunCompletion();
}); });
return s; return s;
} }

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h" #include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h" #include "tensorflow/core/protobuf/master.pb.h"
@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted {
MasterSession( MasterSession(
const SessionOptions& options, const MasterEnv* env, const SessionOptions& options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory); StatsPublisherFactory stats_publisher_factory);
// Initialize the MasterSession for "def". Must be called before Extend(), // Initialize the MasterSession for "def". Must be called before Extend(),
// Run(), or Close(). // Run(), or Close().
// //
// After this method returns, `def` will no longer be valid. // After this method returns, `def` will no longer be valid.
Status Create(GraphDef* def); Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);
// Returns the session handle. // Returns the session handle.
const string& handle() const { return handle_; } const string& handle() const { return handle_; }
@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted {
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_; std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
// The optional session-specific worker cluster.
// TODO(saeta): Convert to std::optional when available.
std::unique_ptr<WorkerCacheInterface> worker_cache_;
// Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
WorkerCacheInterface* get_worker_cache() const;
// The device set used by this session. // The device set used by this session.
DeviceSet devices_; std::unique_ptr<DeviceSet> devices_;
StatsPublisherFactory stats_publisher_factory_; StatsPublisherFactory stats_publisher_factory_;
@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted {
// Private dtor. The client must call Close(). // Private dtor. The client must call Close().
virtual ~MasterSession(); virtual ~MasterSession();
// Creates sessions on all workers.
//
// If this session is operating using the new ClusterSpec propagation behavior
// call this method in order to propagate the cluster membership to all
// workers.
Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
Status StartStep(const BuildGraphOptions& opts, int64* count, Status StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** graph, bool is_partial); ReffedClientGraph** graph, bool is_partial);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref, void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted {
MutableRunStepResponseWrapper* resp); MutableRunStepResponseWrapper* resp);
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp); MutableRunStepResponseWrapper* resp);
void MarkRunCompletion();
void UpdateLastAccessTime(); void UpdateLastAccessTime();
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);

View File

@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const {
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; } const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
const string& InMemoryRunGraphRequest::session_handle() const {
return session_handle_;
}
void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
session_handle_ = handle;
}
const string& InMemoryRunGraphRequest::graph_handle() const { const string& InMemoryRunGraphRequest::graph_handle() const {
return graph_handle_; return graph_handle_;
} }
@ -320,6 +328,7 @@ void InMemoryRunGraphRequest::set_is_last_partial_run(
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
if (!proto_version_) { if (!proto_version_) {
proto_version_.reset(new RunGraphRequest); proto_version_.reset(new RunGraphRequest);
proto_version_->set_session_handle(session_handle());
proto_version_->set_graph_handle(graph_handle()); proto_version_->set_graph_handle(graph_handle());
proto_version_->set_step_id(step_id()); proto_version_->set_step_id(step_id());
*proto_version_->mutable_exec_opts() = exec_opts(); *proto_version_->mutable_exec_opts() = exec_opts();
@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
return *proto_version_; return *proto_version_;
} }
const string& MutableProtoRunGraphRequest::session_handle() const {
return request_.session_handle();
}
void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
request_.set_session_handle(handle);
}
const string& MutableProtoRunGraphRequest::graph_handle() const { const string& MutableProtoRunGraphRequest::graph_handle() const {
return request_.graph_handle(); return request_.graph_handle();
} }
@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request) ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
: request_(request) {} : request_(request) {}
const string& ProtoRunGraphRequest::session_handle() const {
return request_->session_handle();
}
const string& ProtoRunGraphRequest::graph_handle() const { const string& ProtoRunGraphRequest::graph_handle() const {
return request_->graph_handle(); return request_->graph_handle();
} }

View File

@ -223,6 +223,10 @@ class RunGraphRequestWrapper {
public: public:
virtual ~RunGraphRequestWrapper() {} virtual ~RunGraphRequestWrapper() {}
// The session handle used to register the graph. If empty, a single global
// namespace is used.
virtual const string& session_handle() const = 0;
// REQUIRED: graph_handle must be returned by a RegisterGraph call // REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService. // to the same WorkerService.
virtual const string& graph_handle() const = 0; virtual const string& graph_handle() const = 0;
@ -262,6 +266,7 @@ class RunGraphRequestWrapper {
// See `RunGraphRequestWrapper` above for a description of the fields. // See `RunGraphRequestWrapper` above for a description of the fields.
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
public: public:
virtual void set_session_handle(const string& handle) = 0;
virtual void set_graph_handle(const string& handle) = 0; virtual void set_graph_handle(const string& handle) = 0;
virtual void set_step_id(int64 step_id) = 0; virtual void set_step_id(int64 step_id) = 0;
virtual ExecutorOpts* mutable_exec_opts() = 0; virtual ExecutorOpts* mutable_exec_opts() = 0;
@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
public: public:
// RunGraphRequestWrapper methods. // RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override; const string& graph_handle() const override;
int64 step_id() const override; int64 step_id() const override;
const ExecutorOpts& exec_opts() const override; const ExecutorOpts& exec_opts() const override;
@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override; const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods. // MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override; void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override; void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override; ExecutorOpts* mutable_exec_opts() override;
@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
void set_is_last_partial_run(bool is_last_partial_run) override; void set_is_last_partial_run(bool is_last_partial_run) override;
private: private:
string session_handle_;
string graph_handle_; string graph_handle_;
int64 step_id_; int64 step_id_;
ExecutorOpts exec_opts_; ExecutorOpts exec_opts_;
@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
public: public:
// RunGraphRequestWrapper methods. // RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override; const string& graph_handle() const override;
int64 step_id() const override; int64 step_id() const override;
const ExecutorOpts& exec_opts() const override; const ExecutorOpts& exec_opts() const override;
@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override; const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods. // MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override; void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override; void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override; ExecutorOpts* mutable_exec_opts() override;
@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
ProtoRunGraphRequest(const RunGraphRequest* request); ProtoRunGraphRequest(const RunGraphRequest* request);
// RunGraphRequestWrapper methods. // RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override; const string& graph_handle() const override;
int64 step_id() const override; int64 step_id() const override;
const ExecutorOpts& exec_opts() const override; const ExecutorOpts& exec_opts() const override;

View File

@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/remote_device.h" #include "tensorflow/core/distributed_runtime/remote_device.h"
#include <vector> #include <vector>
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/protobuf/worker.pb.h"
@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) {
class RemoteDevice : public Device { class RemoteDevice : public Device {
public: public:
RemoteDevice(Env* env, const DeviceAttributes& da) RemoteDevice(Env* env, const DeviceAttributes& da)
: Device(env, da, nullptr), : Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {}
local_dev_name_(GetLocalDeviceName(da.name())) {}
Status Sync() override { return Status::OK(); } Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
GetStatusResponse resp; GetStatusResponse resp;
}; };
Call* call = new Call; Call* call = new Call;
auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) { auto cb = [env, worker_cache, worker_name, done, wi,
call](const Status& status) {
Status s = status;
std::vector<Device*> remote_devices; std::vector<Device*> remote_devices;
if (s.ok()) { auto cleanup = gtl::MakeCleanup(
remote_devices.reserve(call->resp.device_attributes_size()); [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
for (const DeviceAttributes& da : call->resp.device_attributes()) {
auto d = new RemoteDevice(env, da);
remote_devices.push_back(d);
}
}
worker_cache->ReleaseWorker(worker_name, wi); worker_cache->ReleaseWorker(worker_name, wi);
done(s, &remote_devices); done(s, &remote_devices);
delete call; delete call;
});
if (s.ok()) {
DeviceNameUtils::ParsedName worker_name_parsed;
if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
!worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
!worker_name_parsed.has_task) {
s = errors::InvalidArgument("Could not parse worker name: ",
worker_name);
LOG(WARNING) << s;
return;
}
remote_devices.reserve(call->resp.device_attributes_size());
for (const DeviceAttributes& da : call->resp.device_attributes()) {
DeviceNameUtils::ParsedName device_name_parsed;
CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
<< "Device attribute name '" << da.name() << "' could not be "
<< "parsed. Device Attribute: " << da.DebugString();
// Preserve the exact name, if possible.
// TODO(b/37868888): Simplify when legacy device name formats removed.
if (device_name_parsed.job == worker_name_parsed.job &&
device_name_parsed.replica == worker_name_parsed.replica &&
device_name_parsed.task == worker_name_parsed.task) {
auto d = new RemoteDevice(env, da);
remote_devices.push_back(d);
} else {
DeviceAttributes da_rewritten = da;
da_rewritten.set_name(DeviceNameUtils::FullName(
worker_name_parsed.job, worker_name_parsed.replica,
worker_name_parsed.task, device_name_parsed.type,
device_name_parsed.id));
auto d = new RemoteDevice(env, da_rewritten);
remote_devices.push_back(d);
}
}
}
}; };
wi->GetStatusAsync(&call->req, &call->resp, cb); wi->GetStatusAsync(&call->req, &call->resp, cb);
} }

View File

@ -25,6 +25,23 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
struct WorkerSession;
// RemoteRendezvous follow a 2-part initialization. First the objects are
// constructed. Eventually, they will be initialized. Clients of the
// RendezvousMgrInterface must guarantee to call Initialize on the returned
// RemoteRendezvous eventually.
//
// Partially initialized RemoteRendezvous must respect the Rendezvous interface
// (i.e. Send() must never block), however implementations are not expected to
// actually perform the underlying operations until after the RemoteRendezvous
// has been Initialize'd.
class RemoteRendezvous : public Rendezvous {
public:
// Fully construct the RemoteRendezvous.
virtual Status Initialize(WorkerSession* session) = 0;
};
// RendezvousMgr keeps track of a set of local rendezvous instances. // RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr // All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id" // until the tensor is received. Each global unique "step_id"
@ -51,7 +68,10 @@ class RendezvousMgrInterface {
// Returns Rendezvous supporting send and recv among workers in the // Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the // "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance. // returned Rendezvous instance.
virtual Rendezvous* Find(int64 step_id) = 0; //
// Note: the caller must guarantee to eventually call Initialize on the
// returned RemoteRendezvous
virtual RemoteRendezvous* Find(int64 step_id) = 0;
// Finds the local rendezvous instance for the "step_id". Runs // Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs. // "done" when the tensor for "key" is produced or an error occurs.

View File

@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
}; };
// static utility function // static utility function
RendezvousMgrInterface* NewRpcRendezvousMgr( RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
const WorkerEnv* env, const string& worker_name, return new RpcRendezvousMgr(env);
WorkerCacheInterface* worker_cache) {
return new RpcRendezvousMgr(env, worker_name, worker_cache);
} }
} // namespace } // namespace
@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() {
// TODO(mrry): Refactor the *Env classes so that it is less fiddly // TODO(mrry): Refactor the *Env classes so that it is less fiddly
// to destroy them. // to destroy them.
// Shut down all outstanding rendezvous.
delete worker_env_.rendezvous_mgr;
// We must delete graph_mgr before device_mgr, due to shared // We must delete graph_mgr before device_mgr, due to shared
// ownership of OpKernels in the executors. (The graph_mgr will // ownership of OpKernels in the executors. (The graph_mgr will
// free all stateless OpKernels, and pass over borrowed stateful // free all stateless OpKernels, and pass over borrowed stateful
@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() {
// OpSegments.) // OpSegments.)
if (worker_env_.session_mgr != nullptr) { if (worker_env_.session_mgr != nullptr) {
delete worker_env_.session_mgr; // Deletes graph_mgr's. delete worker_env_.session_mgr; // Deletes graph_mgr's.
} } else {
// Note: session_mgr's legacy_session_ deletes device_mgr now.
delete worker_env_.device_mgr; delete worker_env_.device_mgr;
}
// Do not delete (as these are not owned by the server): // Do not delete (as these are not owned by the server):
// - master_env_.env // - master_env_.env
@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() {
// - worker_env_.compute_pool // - worker_env_.compute_pool
} }
Status GrpcServer::Init(ServiceInitFunction service_func, Status GrpcServer::Init(
RendezvousMgrCreationFunction rendevous_mgr_func) { ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
mutex_lock l(mu_); mutex_lock l(mu_);
CHECK_EQ(state_, NEW); CHECK_EQ(state_, NEW);
master_env_.env = env_; master_env_.env = env_;
@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
"/task:", server_def_.task_index()); "/task:", server_def_.task_index());
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix, TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
&master_env_.local_devices)); &master_env_.local_devices));
worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices); worker_env_.local_devices = master_env_.local_devices;
worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
? new RpcRendezvousMgr(&worker_env_)
: rendezvous_mgr_func(&worker_env_);
string unused; string unused;
string default_worker_name; string default_worker_name;
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(), if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
} }
WorkerCacheInterface* worker_cache; WorkerCacheInterface* worker_cache;
TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache)); WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
TF_RETURN_IF_ERROR(
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
CHECK_NE(nullptr, worker_cache); CHECK_NE(nullptr, worker_cache);
// Set up worker environment. // Set up worker environment.
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
rendevous_mgr_func == nullptr ?
new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr( worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
std::unique_ptr<WorkerCacheInterface>(worker_cache), std::unique_ptr<WorkerCacheInterface>(worker_cache),
std::move(rendezvous_mgr),
[this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
return WorkerCacheFactory(server_def, worker_cache); WorkerCacheFactoryOptions options(server_def);
return WorkerCacheFactory(options, worker_cache);
}); });
worker_env_.compute_pool = ComputePool(sess_opts); worker_env_.compute_pool = ComputePool(sess_opts);
@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
master_env_.master_session_factory = master_env_.master_session_factory =
[config]( [config](
SessionOptions options, const MasterEnv* env, SessionOptions options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs) { std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set) {
options.config.MergeFrom(config); options.config.MergeFrom(config);
return new MasterSession(options, env, std::move(remote_devs), return new MasterSession(options, env, std::move(remote_devs),
std::move(worker_cache), std::move(device_set),
CreateNoOpStatsPublisher); CreateNoOpStatsPublisher);
}; };
master_env_.worker_cache_factory =
[this](const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) {
return WorkerCacheFactory(options, worker_cache);
};
// Provide direct access to the master from in-process clients. // Provide direct access to the master from in-process clients.
LocalMaster::Register(target(), master_impl_.get(), LocalMaster::Register(target(), master_impl_.get(),
@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
return Status::OK(); return Status::OK();
} }
Status GrpcServer::Init() { Status GrpcServer::Init() { return Init(nullptr, nullptr); }
return Init(nullptr, nullptr);
}
Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) { GrpcChannelSpec* channel_spec) {
for (const auto& job : server_def.cluster().job()) { for (const auto& job : options.cluster_def->job()) {
std::map<int, string> host_ports; std::map<int, string> host_ports;
for (const auto& task : job.tasks()) { for (const auto& task : job.tasks()) {
string& host_port = host_ports[task.first]; string& host_port = host_ports[task.first];
@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
task.first, "\": ", host_port, " and ", task.first, "\": ", host_port, " and ",
task.second); task.second);
} }
if (job.name() == server_def.job_name() && if (job.name() == *options.job_name && task.first == options.task_index) {
task.first == server_def.task_index()) {
host_port = strings::StrCat("localhost:", bound_port_); host_port = strings::StrCat("localhost:", bound_port_);
} else { } else {
host_port = task.second; host_port = task.second;
@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
return Status::OK(); return Status::OK();
} }
Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def, Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) { WorkerCacheInterface** worker_cache) {
string name_prefix = if (options.job_name == nullptr || options.job_name->empty()) {
strings::StrCat("/job:", server_def.job_name(), "/replica:0", Status s = errors::InvalidArgument(
"/task:", server_def.task_index()); "The master (current machine) is not included in the provided "
"cluster_def. ",
options.cluster_def->DebugString());
LOG(WARNING) << s;
return s;
}
GrpcChannelSpec channel_spec; GrpcChannelSpec channel_spec;
TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
std::unique_ptr<GrpcChannelCache> channel_cache(
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
"/task:", options.task_index);
std::unique_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
channel_spec, GetChannelCreationFunction(server_def)));
const string host_port = channel_cache->TranslateTask(name_prefix); const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port; int requested_port;
@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
return ::grpc::InsecureServerCredentials(); return ::grpc::InsecureServerCredentials();
} }
ChannelCreationFunction GrpcServer::GetChannelCreationFunction( ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
const ServerDef& server_def) const {
// We can do this because SparseGrpcChannelCache is robust to nullptr being // We can do this because SparseGrpcChannelCache is robust to nullptr being
// returned by the channel creation function // returned by the channel creation function
return ConvertToChannelCreationFunction(NewHostPortGrpcChannel); return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);

View File

@ -37,9 +37,7 @@ class GrpcWorker;
class Master; class Master;
// function that creates a RendezvousMgr. // function that creates a RendezvousMgr.
typedef std::function<RendezvousMgrInterface*( typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
const WorkerEnv*, const std::string& worker_name,
WorkerCacheInterface* worker_cache)>
RendezvousMgrCreationFunction; RendezvousMgrCreationFunction;
// function that registers a service to the server. The service needs to // function that registers a service to the server. The service needs to
@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface {
protected: protected:
Status Init(ServiceInitFunction service_func, Status Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func); const RendezvousMgrCreationFunction& rendezvous_mgr_func);
Status Init(); Status Init();
@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface {
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
const ServerDef& server_def) const; const ServerDef& server_def) const;
virtual ChannelCreationFunction GetChannelCreationFunction( virtual ChannelCreationFunction GetChannelCreationFunction() const;
const ServerDef& server_def) const;
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env); virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
// Creates a WorkerCacheInterface for a session. // Creates a WorkerCacheInterface for a session.
Status WorkerCacheFactory(const ServerDef& server_def, Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache); WorkerCacheInterface** worker_cache);
// Parses a ServerDef into a GrpcChannelSpec. // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
Status ParseChannelSpec(const ServerDef& server_def, Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec); GrpcChannelSpec* channel_spec);
// Returns the port to which this server is bound. // Returns the port to which this server is bound.

View File

@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix);
/* static */ /* static */
Status GrpcSession::Create(const SessionOptions& options, Status GrpcSession::Create(const SessionOptions& options,
std::unique_ptr<GrpcSession>* out_session) { std::unique_ptr<GrpcSession>* out_session) {
std::unique_ptr<GrpcSession> ret(new GrpcSession(options)); std::unique_ptr<GrpcSession> session(new GrpcSession(options));
std::unique_ptr<MasterInterface> master; std::unique_ptr<MasterInterface> master;
// For testing, we enable the client to disable the use of the local // For testing, we enable the client to disable the use of the local
// master registry, so that the RPC stack is exercised. // master registry, so that the RPC stack is exercised.
@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options,
options.target.substr(kSchemePrefixLength), &master_channel)); options.target.substr(kSchemePrefixLength), &master_channel));
master.reset(NewGrpcMaster(master_channel)); master.reset(NewGrpcMaster(master_channel));
} }
ret->SetRemoteMaster(std::move(master)); session->SetRemoteMaster(std::move(master));
*out_session = std::move(ret); *out_session = std::move(session);
return Status::OK(); return Status::OK();
} }
@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options,
CreateSessionRequest req; CreateSessionRequest req;
*req.mutable_config() = options_.config; *req.mutable_config() = options_.config;
*req.mutable_graph_def() = graph; *req.mutable_graph_def() = graph;
req.set_target(options_.target);
ReEncodeConsts(req.mutable_graph_def()); ReEncodeConsts(req.mutable_graph_def());
CreateSessionResponse resp; CreateSessionResponse resp;
Status s = master_->CreateSession(call_options, &req, &resp); Status s = master_->CreateSession(call_options, &req, &resp);

View File

@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
// completes, and we may decide to bound some of the request // completes, and we may decide to bound some of the request
// types. // types.
ENQUEUE_REQUEST(GetStatus, false); ENQUEUE_REQUEST(GetStatus, false);
ENQUEUE_REQUEST(CreateWorkerSession, false);
ENQUEUE_REQUEST(CleanupAll, false); ENQUEUE_REQUEST(CleanupAll, false);
ENQUEUE_REQUEST(RegisterGraph, false); ENQUEUE_REQUEST(RegisterGraph, false);
ENQUEUE_REQUEST(DeregisterGraph, false); ENQUEUE_REQUEST(DeregisterGraph, false);
@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(GetStatus, false); ENQUEUE_REQUEST(GetStatus, false);
} }
void CreateWorkerSessionHandler(
WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
call) {
Schedule([this, call]() {
Status s = worker_->CreateWorkerSession(&call->request, &call->response);
call->SendResponse(ToGrpcStatus(s));
});
ENQUEUE_REQUEST(CreateWorkerSession, false);
}
void CleanupAllHandler( void CleanupAllHandler(
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) { WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
Schedule([this, call]() { Schedule([this, call]() {
@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
::grpc::ByteBuffer* response, ::grpc::ByteBuffer* response,
StatusCallback done) { StatusCallback done) {
const int64 step_id = request->step_id(); const int64 step_id = request->step_id();
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
const string& key = request->rendezvous_key(); const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed; Rendezvous::ParsedKey parsed;
@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
// of execution of the callback lambda body below, an RPC // of execution of the callback lambda body below, an RPC
// cancellation should abort the rendezvous. // cancellation should abort the rendezvous.
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); }); opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
session->rendezvous_mgr->RecvLocalAsync( env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed, step_id, parsed,
[opts, response, done, src_dev](const Status& status, [opts, response, done, src_dev](const Status& status,
const Rendezvous::Args& send_args, const Rendezvous::Args& send_args,

View File

@ -38,9 +38,8 @@ namespace {
class RpcRemoteRendezvous : public BaseRemoteRendezvous { class RpcRemoteRendezvous : public BaseRemoteRendezvous {
public: public:
RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name, RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
WorkerCacheInterface* cache, int64 step_id) : BaseRemoteRendezvous(env, step_id, false) {}
: BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {}
protected: protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
private: private:
~RpcRemoteRendezvous() override {} ~RpcRemoteRendezvous() override {}
WorkerCacheInterface* const cache_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous); TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
}; };
@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() {
return call_freelist; return call_freelist;
} }
// A private cache that wraps worker_cache and allows reuse of
// WorkerInterface objects.
class WorkerFreeListCache : public WorkerCacheInterface {
public:
explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {}
~WorkerFreeListCache() {
for (auto p : workers_) {
wrapped_->ReleaseWorker(p.first, p.second.worker);
}
}
void ListWorkers(std::vector<string>* workers) const override {
wrapped_->ListWorkers(workers);
}
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);
if (p != workers_.end()) {
return p->second.worker;
}
WorkerState state;
state.worker = wrapped_->CreateWorker(target);
if (state.worker != nullptr) {
workers_.insert(std::make_pair(target, state));
}
return state.worker;
}
void ReleaseWorker(const string& target, WorkerInterface* worker) override {
// TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
}
bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) override {
return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
}
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
StatusCallback done) override {
wrapped_->GetDeviceLocalityAsync(device, locality, done);
}
void SetLogging(bool active) override { wrapped_->SetLogging(active); }
void ClearLogs() override { wrapped_->ClearLogs(); }
bool RetrieveLogs(int64 step_id, StepStats* ss) override {
return wrapped_->RetrieveLogs(step_id, ss);
}
private:
WorkerCacheInterface* wrapped_;
// Information kept per created WorkerInterface.
struct WorkerState {
WorkerInterface* worker;
// TODO(jeff,sanjay): Add reference count if we support eviction.
};
// TODO(jeff,sanjay): Eviction when the map becomes too big.
mutex mu_;
std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
};
void RpcRemoteRendezvous::RecvFromRemoteAsync( void RpcRemoteRendezvous::RecvFromRemoteAsync(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) { DoneCallback done) {
CHECK(is_initialized());
Status s; Status s;
// Prepare a RecvTensor call that can handle being aborted. // Prepare a RecvTensor call that can handle being aborted.
@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
s = errors::Internal(parsed.src_device, s = errors::Internal(parsed.src_device,
" is invalid remote source device."); " is invalid remote source device.");
} }
WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_); WorkerSession* sess = session();
WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
if (s.ok() && rwi == nullptr) { if (s.ok() && rwi == nullptr) {
s = errors::Internal("No worker known as ", call->src_worker_); s = errors::Internal("No worker known as ", call->src_worker_);
} }
Device* dst_device; Device* dst_device;
if (s.ok()) { if (s.ok()) {
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
} }
if (!s.ok()) { if (!s.ok()) {
get_call_freelist()->Release(call, cache_); if (rwi != nullptr) {
sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
}
get_call_freelist()->Release(call, sess->worker_cache.get());
done(s, Args(), recv_args, Tensor{}, false); done(s, Args(), recv_args, Tensor{}, false);
return; return;
} }
@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// current status should be bad. // current status should be bad.
Status s = call->status(); Status s = call->status();
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
cache_->ReleaseWorker(call->src_worker_, call->wi_); session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
call->wi_ = nullptr; call->wi_ = nullptr;
get_call_freelist()->Release(call, cache_); get_call_freelist()->Release(call, session()->worker_cache.get());
Unref(); Unref();
}); });
} }
} // namespace } // namespace
RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env, RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
const string& worker_name, : BaseRendezvousMgr(env) {}
WorkerCacheInterface* worker_cache)
: BaseRendezvousMgr(env, worker_name),
cache_(new WorkerFreeListCache(worker_cache)) {}
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
const WorkerEnv* worker_env, const WorkerEnv* worker_env) {
const string& worker_name) { return new RpcRemoteRendezvous(worker_env, step_id);
return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(),
step_id);
} }
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
namespace tensorflow { namespace tensorflow {
class DeviceMgr;
// RendezvousMgr keeps track of a set of local rendezvous instances. // RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr // All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id" // until the tensor is received. Each global unique "step_id"
@ -44,17 +44,12 @@ namespace tensorflow {
// RendezvousMgr must have keys generated by Rendezvous::CreateKey. // RendezvousMgr must have keys generated by Rendezvous::CreateKey.
class RpcRendezvousMgr : public BaseRendezvousMgr { class RpcRendezvousMgr : public BaseRendezvousMgr {
public: public:
explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name, explicit RpcRendezvousMgr(const WorkerEnv* env);
WorkerCacheInterface* worker_cache);
protected: protected:
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env, BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);
const string& session_name) override;
private: private:
// Private cache_ that allows us to reuse WorkerInterface objects.
std::unique_ptr<WorkerCacheInterface> cache_;
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr); TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
}; };

View File

@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test {
: cache_(new DummyWorkerCache), : cache_(new DummyWorkerCache),
worker_session_("/job:mnist/replica:1/task:2", worker_session_("/job:mnist/replica:1/task:2",
std::unique_ptr<WorkerCacheInterface>(cache_), std::unique_ptr<WorkerCacheInterface>(cache_),
std::unique_ptr<RendezvousMgrInterface>(), std::unique_ptr<DeviceMgr>(),
std::unique_ptr<GraphMgr>()), std::unique_ptr<GraphMgr>()),
rmgr_(&env, worker_session_.worker_name, cache_) { rmgr_(&env) {
env.env = Env::Default(); env.env = Env::Default();
} }
@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) {
"/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ {
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
Rendezvous::Args args; Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ // Explicit Abort(). { // Explicit Abort().
const int64 step_id = 123; const int64 step_id = 123;
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
SchedClosure([this, rendez]() { SchedClosure([this, rendez]() {
env.env->SleepForMicroseconds(100 * 1000); env.env->SleepForMicroseconds(100 * 1000);
@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING); Tensor val(DT_STRING);
bool val_dead = false; bool val_dead = false;
Rendezvous::Args args; Rendezvous::Args args;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
} }
{ // Cleanup causes Abort(). { // Cleanup causes Abort().
const int64 step_id = 321; const int64 step_id = 321;
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
SchedClosure([this, step_id]() { SchedClosure([this, step_id]() {
env.env->SleepForMicroseconds(100 * 1000); env.env->SleepForMicroseconds(100 * 1000);
@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING); Tensor val(DT_STRING);
bool val_dead = false; bool val_dead = false;
Rendezvous::Args args; Rendezvous::Args args;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
} }
} }
@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ {
const int64 step_id = 123; const int64 step_id = 123;
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
Rendezvous::Args args; Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
"/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ {
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
Rendezvous::Args args; Rendezvous::Args args;
args.device_context = dc; args.device_context = dc;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
} }
{ {

View File

@ -17,8 +17,9 @@ limitations under the License.
#include <utility> #include <utility>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow { namespace tensorflow {
@ -26,23 +27,12 @@ namespace tensorflow {
SessionMgr::SessionMgr( SessionMgr::SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name, WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache, std::unique_ptr<WorkerCacheInterface> default_worker_cache,
std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory)
: SessionMgr(
worker_env, default_worker_name, std::move(default_worker_cache),
default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {}
SessionMgr::SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
RendezvousMgrInterface* default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory) WorkerCacheFactory worker_cache_factory)
: worker_env_(worker_env), : worker_env_(worker_env),
legacy_session_( legacy_session_(default_worker_name, std::move(default_worker_cache),
default_worker_name, std::move(default_worker_cache), std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
std::unique_ptr<RendezvousMgrInterface>(default_rendezvous_mgr),
std::unique_ptr<GraphMgr>( std::unique_ptr<GraphMgr>(
new GraphMgr(worker_env, default_rendezvous_mgr))), new GraphMgr(worker_env, worker_env->device_mgr))),
worker_cache_factory_(std::move(worker_cache_factory)) {} worker_cache_factory_(std::move(worker_cache_factory)) {}
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
Status SessionMgr::CreateSession(const string& session, Status SessionMgr::CreateSession(const string& session,
const ServerDef& server_def) { const ServerDef& server_def) {
mutex_lock l(mu_); mutex_lock l(mu_);
if (session.empty()) {
return errors::InvalidArgument("Session must be non-empty.");
}
const string worker_name = WorkerNameFromServerDef(server_def); const string worker_name = WorkerNameFromServerDef(server_def);
WorkerCacheInterface* worker_cache = nullptr; WorkerCacheInterface* worker_cache = nullptr;
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr( std::vector<Device*> renamed_devices;
new RpcRendezvousMgr(worker_env_, worker_name, worker_cache)); for (Device* d : worker_env_->local_devices) {
renamed_devices.push_back(
RenamedDevice::NewRenamedDevice(worker_name, d, false));
}
std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
std::unique_ptr<GraphMgr> graph_mgr( std::unique_ptr<GraphMgr> graph_mgr(
new GraphMgr(worker_env_, rendezvous_mgr.get())); new GraphMgr(worker_env_, device_mgr.get()));
std::unique_ptr<WorkerSession> worker_session(new WorkerSession( std::unique_ptr<WorkerSession> worker_session(new WorkerSession(
worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache), worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
std::move(rendezvous_mgr), std::move(graph_mgr))); std::move(device_mgr), std::move(graph_mgr)));
sessions_.insert(std::make_pair(session, std::move(worker_session))); sessions_.insert(std::make_pair(session, std::move(worker_session)));
return Status::OK(); return Status::OK();
@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) {
if (it != sessions_.end()) { if (it != sessions_.end()) {
sessions_.erase(it); sessions_.erase(it);
} }
std::set<string> graph_handles;
for (auto graph_handle_it = sessions_by_graph_handle_.begin();
graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) {
if (graph_handle_it->second == session) {
graph_handles.insert(graph_handle_it->first);
graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it);
if (graph_handle_it == sessions_by_graph_handle_.end()) break;
}
}
for (auto step_id_it = graphs_by_step_id_.begin();
step_id_it != graphs_by_step_id_.end(); ++step_id_it) {
if (graph_handles.find(step_id_it->second) != graph_handles.end()) {
step_id_it = graphs_by_step_id_.erase(step_id_it);
if (step_id_it == graphs_by_step_id_.end()) break;
}
}
return Status::OK(); return Status::OK();
} }
@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) {
WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; } WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; }
WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked(
const string& graph_handle) {
auto it = sessions_by_graph_handle_.find(graph_handle);
if (it == sessions_by_graph_handle_.end()) {
return &legacy_session_;
} else {
return WorkerSessionForSessionUnlocked(it->second);
}
}
WorkerSession* SessionMgr::WorkerSessionForGraphHandle(
const string& graph_handle) {
mutex_lock l(mu_);
return WorkerSessionForGraphHandleUnlocked(graph_handle);
}
WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) {
mutex_lock l(mu_);
auto it = graphs_by_step_id_.find(step_id);
if (it == graphs_by_step_id_.end()) {
return &legacy_session_;
} else {
return WorkerSessionForGraphHandleUnlocked(it->second);
}
}
void SessionMgr::AssociateGraphWithSession(const string& session,
const string& graph_handle) {
mutex_lock l(mu_);
sessions_by_graph_handle_[graph_handle] = session;
}
void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) {
mutex_lock l(mu_);
auto it = sessions_by_graph_handle_.find(graph_handle);
if (it != sessions_by_graph_handle_.end()) {
sessions_by_graph_handle_.erase(it);
}
}
void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle,
const int64 step_id) {
mutex_lock l(mu_);
graphs_by_step_id_[step_id] = graph_handle;
}
void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) {
mutex_lock l(mu_);
auto it = graphs_by_step_id_.find(step_id);
if (it != graphs_by_step_id_.end()) {
graphs_by_step_id_.erase(it);
}
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -30,6 +30,8 @@ struct WorkerEnv;
// SessionMgr keeps track of information related to a given session. // SessionMgr keeps track of information related to a given session.
// //
// SessionMgr runs on the workers.
//
// SessionMgr is threadsafe. // SessionMgr is threadsafe.
class SessionMgr { class SessionMgr {
public: public:
@ -39,7 +41,6 @@ class SessionMgr {
explicit SessionMgr( explicit SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name, WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache, std::unique_ptr<WorkerCacheInterface> default_worker_cache,
std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory); WorkerCacheFactory worker_cache_factory);
~SessionMgr() {} ~SessionMgr() {}
@ -50,49 +51,36 @@ class SessionMgr {
WorkerSession* WorkerSessionForSession(const string& session); WorkerSession* WorkerSessionForSession(const string& session);
WorkerSession* LegacySession(); WorkerSession* LegacySession();
// Locates the worker session for a given graph handle
WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle);
void AssociateGraphWithSession(const string& session,
const string& graph_handle);
void DisassociateGraphFromSession(const string& graph_handle);
// Locates a worker session for a given step id
WorkerSession* WorkerSessionForStepId(const int64 step_id);
void AssociateStepIdWithGraph(const string& graph_handle,
const int64 step_id);
void DisassociateStepIdFromGraph(const int64 step_id);
Status DeleteSession(const string& session); Status DeleteSession(const string& session);
static string WorkerNameFromServerDef(const ServerDef& server_def); static string WorkerNameFromServerDef(const ServerDef& server_def);
private: private:
// Private constructor to work around std::unique_ptr ownership issues.
explicit SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
RendezvousMgrInterface* default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory);
const WorkerEnv* const worker_env_; // Not owned. const WorkerEnv* const worker_env_; // Not owned.
// A note about destruction:
// We must delete graph_mgr before device_mgr, due to shared
// ownership of OpKernels in the executors. (The graph_mgr will
// free all stateless OpKernels, and pass over borrowed stateful
// OpKernels, which are also held in their respective devices'
// OpSegments.)
//
// legacy_session_ owns the worker_env_.device_mgr, and so we must ensure
// that sessions_'s WorkerSessions are deleted (which do not own the
// underlying devices, but instead own RenamedDevices) before
// legacy_session_ is deleted. Further, we must ensure that WorkerSession's
// device_mgr is deleted after WorkerSession's graph_mgr.
WorkerSession legacy_session_; WorkerSession legacy_session_;
const WorkerCacheFactory worker_cache_factory_; const WorkerCacheFactory worker_cache_factory_;
WorkerSession* WorkerSessionForSessionUnlocked(const string& session) WorkerSession* WorkerSessionForSessionUnlocked(const string& session)
EXCLUSIVE_LOCKS_REQUIRED(mu_); EXCLUSIVE_LOCKS_REQUIRED(mu_);
WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_; mutex mu_;
// A map from session identifier to internal session structure. // A map from session identifier to internal session structure.
std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_); std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_);
// A map from graph handles to the session that they belong to.
std::map<string, string> sessions_by_graph_handle_ GUARDED_BY(mu_);
// A map from globally-unique step id's to the corresponding graph handles.
std::map<int64, string> graphs_by_step_id_ GUARDED_BY(mu_);
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest() SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0", : mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(), std::unique_ptr<WorkerCacheInterface>(),
std::unique_ptr<RendezvousMgrInterface>(new RpcRendezvousMgr(
&env_, "/job:mnist/replica:0/task:0", nullptr)),
factory_), factory_),
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {} legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_NE(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
} }
TEST_F(SessionMgrTest, AssociateGraphWithSession) { TEST_F(SessionMgrTest, LegacySession) {
ServerDef server_def; ServerDef server_def;
string session_handle = "test_session_handle"; string session_handle = "";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; EXPECT_EQ(mgr_.LegacySession(), session);
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(session, graph_session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
} }
TEST_F(SessionMgrTest, AssociateStepWithGraph) {
ServerDef server_def;
string session_handle = "test_session_handle";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(session, graph_session);
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(session, step_session);
ASSERT_EQ(graph_session, step_session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(legacy_session_, graph_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) {
ServerDef server_def;
string session_handle = "test_session_handle";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
string graph_handle = "test_graph_handle";
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(legacy_session_, graph_session);
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, WorkerNameFromServerDef) { TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
ServerDef server_def; ServerDef server_def;
server_def.set_job_name("worker"); server_def.set_job_name("worker");

View File

@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
Status s = session->graph_mgr->Register( Status s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(), request->graph_options(), request->session_handle(), request->graph_def(), request->graph_options(),
request->debug_options(), response->mutable_graph_handle()); request->debug_options(), response->mutable_graph_handle());
if (s.ok()) {
env_->session_mgr->AssociateGraphWithSession(request->session_handle(),
response->graph_handle());
}
done(s); done(s);
} }
@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response, DeregisterGraphResponse* response,
StatusCallback done) { StatusCallback done) {
WorkerSession* session = WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); env_->session_mgr->WorkerSessionForSession(request->session_handle());
Status s = session->graph_mgr->Deregister(request->graph_handle()); Status s = session->graph_mgr->Deregister(request->graph_handle());
env_->session_mgr->DisassociateGraphFromSession(request->graph_handle());
done(s); done(s);
} }
@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id,
} }
void Worker::AbortStep(int64 step_id) { void Worker::AbortStep(int64 step_id) {
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
Rendezvous* rendez = session->rendezvous_mgr->Find(step_id);
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
// Delay a bit before aborting the step. This way, the root // Delay a bit before aborting the step. This way, the root
// cause may return first back to the client instead of this // cause may return first back to the client instead of this
@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
const int64 step_id = request->step_id(); const int64 step_id = request->step_id();
TRACEPRINTF("RunGraph: %lld", step_id); TRACEPRINTF("RunGraph: %lld", step_id);
WorkerSession* session = WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); env_->session_mgr->WorkerSessionForSession(request->session_handle());
env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id);
GraphMgr::NamedTensors in; GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out); Status s = PrepareRunGraph(request, &in, out);
@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
} }
CostGraphDef* cost_graph = response->mutable_cost_graph(); CostGraphDef* cost_graph = response->mutable_cost_graph();
session->graph_mgr->ExecuteAsync( session->graph_mgr->ExecuteAsync(
request->graph_handle(), step_id, request->exec_opts(), collector, request->graph_handle(), step_id, session, request->exec_opts(),
cost_graph, cm, in, collector, cost_graph, cm, in,
[this, step_id, response, session, cm, out, token, collector, opts, [this, step_id, response, session, cm, out, token, collector, opts,
done](Status s) { done](Status s) {
if (s.ok()) { if (s.ok()) {
@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
const string& graph_handle = request->graph_handle(); const string& graph_handle = request->graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id); TRACEPRINTF("PartialRunGraph: %lld", step_id);
WorkerSession* session = WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(graph_handle); env_->session_mgr->WorkerSessionForSession(request->session_handle());
env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id);
GraphMgr::NamedTensors in; GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out); Status s = PrepareRunGraph(request, &in, out);
@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
[cm]() { cm->StartCancel(); }); [cm]() { cm->StartCancel(); });
} }
session->graph_mgr->ExecuteAsync( session->graph_mgr->ExecuteAsync(
graph_handle, step_id, request->exec_opts(), nullptr /* collector */, graph_handle, step_id, session, request->exec_opts(),
nullptr /* cost_graph */, cm, in, nullptr /* collector */, nullptr /* cost_graph */, cm, in,
[this, token, graph_handle, step_id, cm](Status s) { [this, token, graph_handle, step_id, cm](Status s) {
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response, CleanupGraphResponse* response,
StatusCallback done) { StatusCallback done) {
const int64 step_id = request->step_id(); const int64 step_id = request->step_id();
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); env_->rendezvous_mgr->Cleanup(step_id);
session->rendezvous_mgr->Cleanup(step_id);
done(Status::OK()); done(Status::OK());
} }
@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request,
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
Device** src_dev) { Device** src_dev) {
// Figures out which device the tensor is hosted on. // Figures out which device the tensor is hosted on.
TF_RETURN_IF_ERROR( string local_name = DeviceNameUtils::LocalName(parsed.src_device);
env_->device_mgr->LookupDevice(parsed.src_device, src_dev)); TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
// Does the device have the right incarnation number we expect? // Does the device have the right incarnation number we expect?
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#include <vector>
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
namespace tensorflow { namespace tensorflow {
@ -24,8 +25,10 @@ namespace thread {
class ThreadPool; class ThreadPool;
} // namespace thread } // namespace thread
class Device;
class DeviceMgr; class DeviceMgr;
class Env; class Env;
class RendezvousMgrInterface;
class SessionMgr; class SessionMgr;
// The worker environment class, which holds a bag of pointers to // The worker environment class, which holds a bag of pointers to
@ -38,10 +41,18 @@ struct WorkerEnv {
// session_mgr encapsulates state for each session. // session_mgr encapsulates state for each session.
SessionMgr* session_mgr = nullptr; SessionMgr* session_mgr = nullptr;
// The local devices of this worker. Devices are owned by the device_mgr.
//
// REQUIRES: !local_devices.empty().
std::vector<Device*> local_devices;
// device_mgr manages local devices (cpu and gpu). The WorkerService // device_mgr manages local devices (cpu and gpu). The WorkerService
// is the network interface for managed devices. // is the network interface for managed devices.
DeviceMgr* device_mgr = nullptr; DeviceMgr* device_mgr = nullptr;
// A set of rendezvous keyed by step ids.
RendezvousMgrInterface* rendezvous_mgr = nullptr;
// A pool of threads for scheduling compute work. // A pool of threads for scheduling compute work.
thread::ThreadPool* compute_pool = nullptr; thread::ThreadPool* compute_pool = nullptr;
}; };

Some files were not shown because too many files have changed in this diff Show More