Merge commit for internal changes
This commit is contained in:
commit
15b8f3d65c
2
configure
vendored
2
configure
vendored
@ -385,7 +385,7 @@ fi
|
|||||||
|
|
||||||
# Append CC optimization flags to bazel.rc
|
# Append CC optimization flags to bazel.rc
|
||||||
for opt in $CC_OPT_FLAGS; do
|
for opt in $CC_OPT_FLAGS; do
|
||||||
write_to_bazelrc 'build:opt --cxxopt=$opt --copt=$opt'
|
write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt"
|
||||||
done
|
done
|
||||||
|
|
||||||
# Run the gen_git_source to create links where bazel can track dependencies for
|
# Run the gen_git_source to create links where bazel can track dependencies for
|
||||||
|
@ -58,6 +58,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/cc/saved_model:loader",
|
"//tensorflow/cc/saved_model:loader",
|
||||||
"//tensorflow/cc:gradients",
|
"//tensorflow/cc:gradients",
|
||||||
"//tensorflow/cc:ops",
|
"//tensorflow/cc:ops",
|
||||||
|
"//tensorflow/cc:grad_ops",
|
||||||
"//tensorflow/cc:scope_internal",
|
"//tensorflow/cc:scope_internal",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
@ -91,6 +91,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":array_grad",
|
":array_grad",
|
||||||
":math_grad",
|
":math_grad",
|
||||||
|
":nn_grad",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options,
|
|||||||
const DeviceType& jit_device_name,
|
const DeviceType& jit_device_name,
|
||||||
perftools::gputools::Platform* platform,
|
perftools::gputools::Platform* platform,
|
||||||
Allocator* xla_allocator)
|
Allocator* xla_allocator)
|
||||||
: LocalDevice(options, attrs, xla_allocator),
|
: LocalDevice(options, attrs),
|
||||||
device_ordinal_(device_ordinal),
|
device_ordinal_(device_ordinal),
|
||||||
jit_device_name_(jit_device_name),
|
jit_device_name_(jit_device_name),
|
||||||
xla_allocator_(xla_allocator),
|
xla_allocator_(xla_allocator),
|
||||||
|
@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
|
|||||||
options,
|
options,
|
||||||
Device::BuildDeviceAttributes(
|
Device::BuildDeviceAttributes(
|
||||||
"", type, Bytes(256 << 20), DeviceLocality(),
|
"", type, Bytes(256 << 20), DeviceLocality(),
|
||||||
strings::StrCat("device: XLA compilation device ", type.type())),
|
strings::StrCat("device: XLA compilation device ", type.type()))),
|
||||||
cpu_allocator()),
|
|
||||||
allocator_(new XlaCompilationAllocator()) {}
|
allocator_(new XlaCompilationAllocator()) {}
|
||||||
|
|
||||||
XlaCompilationDevice::~XlaCompilationDevice() {}
|
XlaCompilationDevice::~XlaCompilationDevice() {}
|
||||||
|
@ -668,6 +668,14 @@ class ComputationBuilder {
|
|||||||
// then Build() should be used instead.
|
// then Build() should be used instead.
|
||||||
Computation BuildAndNoteError();
|
Computation BuildAndNoteError();
|
||||||
|
|
||||||
|
// Returns the first error that was encountered while building the
|
||||||
|
// computation. When an error is encountered, by default we return a vacuous
|
||||||
|
// ComputationDataHandle and inform the user of the error that occurred while
|
||||||
|
// building the computation when they make a final call to Build().
|
||||||
|
//
|
||||||
|
// See also set_die_immediately_on_error().
|
||||||
|
Status first_error() const { return first_error_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
using PopulateLiteral = std::function<void(Literal*)>;
|
using PopulateLiteral = std::function<void(Literal*)>;
|
||||||
|
|
||||||
|
@ -201,7 +201,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name,
|
|||||||
if (&argument == retval) {
|
if (&argument == retval) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
compute_function_->setDoesNotAlias(argument.getArgNo() + 1);
|
compute_function_->addAttribute(argument.getArgNo() + 1,
|
||||||
|
llvm::Attribute::NoAlias);
|
||||||
}
|
}
|
||||||
|
|
||||||
ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(
|
ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(
|
||||||
|
@ -196,7 +196,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
|
|||||||
ir_emitter_context_->buffer_assignment().GetTempAllocation()) {
|
ir_emitter_context_->buffer_assignment().GetTempAllocation()) {
|
||||||
kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size());
|
kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size());
|
||||||
}
|
}
|
||||||
kernel->setDoesNotAlias(temp_buffer_arg_no + 1);
|
kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias);
|
||||||
|
|
||||||
// Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
|
// Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
|
||||||
// treats it as a CUDA kernel.
|
// treats it as a CUDA kernel.
|
||||||
|
@ -705,7 +705,8 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
|
|||||||
CHECK(ShapeUtil::IsArray(instruction->shape()) &&
|
CHECK(ShapeUtil::IsArray(instruction->shape()) &&
|
||||||
ShapeUtil::IsArray(operand->shape()));
|
ShapeUtil::IsArray(operand->shape()));
|
||||||
|
|
||||||
if (instruction->IsElementwiseOnOperand(operand_no) &&
|
if ((instruction->IsElementwiseOnOperand(operand_no) ||
|
||||||
|
InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) &&
|
||||||
!ShapeUtil::IsScalar(operand->shape()) &&
|
!ShapeUtil::IsScalar(operand->shape()) &&
|
||||||
ShapeUtil::Rank(operand->shape()) ==
|
ShapeUtil::Rank(operand->shape()) ==
|
||||||
ShapeUtil::Rank(instruction->shape())) {
|
ShapeUtil::Rank(instruction->shape())) {
|
||||||
|
@ -248,6 +248,15 @@ class LayoutAssignment : public HloPassInterface {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This method can be overriden to mark instructions as requiring the operands
|
||||||
|
// to have the same layout as the result, for performance or correctness. This
|
||||||
|
// will propagate constraints through the instruction from the result into the
|
||||||
|
// operands.
|
||||||
|
virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
|
||||||
|
const HloInstruction* instruction) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Construct contraints and assign layouts to all instructions in the
|
// Construct contraints and assign layouts to all instructions in the
|
||||||
// computation satisfying the given ComputationLayout. Layouts constraints are
|
// computation satisfying the given ComputationLayout. Layouts constraints are
|
||||||
// added, then propagated until all LogicalBuffers in the computation are
|
// added, then propagated until all LogicalBuffers in the computation are
|
||||||
|
@ -244,8 +244,11 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
|||||||
}
|
}
|
||||||
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
|
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"cannot concatenate arrays with different ranks: %lld vs %lld",
|
"Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
|
||||||
ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape));
|
"(%s)",
|
||||||
|
ShapeUtil::Rank(*arg_shape),
|
||||||
|
ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
|
||||||
|
ShapeUtil::HumanString(*shape).c_str());
|
||||||
}
|
}
|
||||||
if (arg_shape->element_type() != shape->element_type()) {
|
if (arg_shape->element_type() != shape->element_type()) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
|
@ -118,6 +118,7 @@ set(tf_proto_text_srcs
|
|||||||
"tensorflow/core/framework/types.proto"
|
"tensorflow/core/framework/types.proto"
|
||||||
"tensorflow/core/framework/versions.proto"
|
"tensorflow/core/framework/versions.proto"
|
||||||
"tensorflow/core/lib/core/error_codes.proto"
|
"tensorflow/core/lib/core/error_codes.proto"
|
||||||
|
"tensorflow/core/protobuf/cluster.proto"
|
||||||
"tensorflow/core/protobuf/config.proto"
|
"tensorflow/core/protobuf/config.proto"
|
||||||
"tensorflow/core/protobuf/debug.proto"
|
"tensorflow/core/protobuf/debug.proto"
|
||||||
"tensorflow/core/protobuf/rewriter_config.proto"
|
"tensorflow/core/protobuf/rewriter_config.proto"
|
||||||
|
@ -22,6 +22,7 @@ set(tf_op_lib_names
|
|||||||
"image_ops"
|
"image_ops"
|
||||||
"io_ops"
|
"io_ops"
|
||||||
"linalg_ops"
|
"linalg_ops"
|
||||||
|
"lookup_ops"
|
||||||
"logging_ops"
|
"logging_ops"
|
||||||
"math_ops"
|
"math_ops"
|
||||||
"nn_ops"
|
"nn_ops"
|
||||||
|
@ -203,6 +203,7 @@ add_python_module("tensorflow/python/estimator")
|
|||||||
add_python_module("tensorflow/python/estimator/export")
|
add_python_module("tensorflow/python/estimator/export")
|
||||||
add_python_module("tensorflow/python/estimator/inputs")
|
add_python_module("tensorflow/python/estimator/inputs")
|
||||||
add_python_module("tensorflow/python/estimator/inputs/queues")
|
add_python_module("tensorflow/python/estimator/inputs/queues")
|
||||||
|
add_python_module("tensorflow/python/feature_column")
|
||||||
add_python_module("tensorflow/python/framework")
|
add_python_module("tensorflow/python/framework")
|
||||||
add_python_module("tensorflow/python/grappler")
|
add_python_module("tensorflow/python/grappler")
|
||||||
add_python_module("tensorflow/python/kernel_tests")
|
add_python_module("tensorflow/python/kernel_tests")
|
||||||
@ -596,6 +597,7 @@ GENERATE_PYTHON_OP_LIB("image_ops")
|
|||||||
GENERATE_PYTHON_OP_LIB("io_ops")
|
GENERATE_PYTHON_OP_LIB("io_ops")
|
||||||
GENERATE_PYTHON_OP_LIB("linalg_ops")
|
GENERATE_PYTHON_OP_LIB("linalg_ops")
|
||||||
GENERATE_PYTHON_OP_LIB("logging_ops")
|
GENERATE_PYTHON_OP_LIB("logging_ops")
|
||||||
|
GENERATE_PYTHON_OP_LIB("lookup_ops")
|
||||||
GENERATE_PYTHON_OP_LIB("nn_ops")
|
GENERATE_PYTHON_OP_LIB("nn_ops")
|
||||||
GENERATE_PYTHON_OP_LIB("parsing_ops")
|
GENERATE_PYTHON_OP_LIB("parsing_ops")
|
||||||
GENERATE_PYTHON_OP_LIB("random_ops")
|
GENERATE_PYTHON_OP_LIB("random_ops")
|
||||||
|
@ -710,25 +710,6 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
|
||||||
name = "identity_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["python/kernel_tests/bijectors/identity_test.py"],
|
|
||||||
additional_deps = [
|
|
||||||
":bijectors_py",
|
|
||||||
":distributions_py",
|
|
||||||
"//third_party/py/numpy",
|
|
||||||
"@six_archive//:six",
|
|
||||||
"//tensorflow/contrib/linalg:linalg_py",
|
|
||||||
"//tensorflow/python:array_ops",
|
|
||||||
"//tensorflow/python:client_testlib",
|
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
|
||||||
"//tensorflow/python:framework_test_lib",
|
|
||||||
"//tensorflow/python:math_ops",
|
|
||||||
"//tensorflow/python:platform_test",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "inline_test",
|
name = "inline_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -25,6 +25,7 @@ from __future__ import print_function
|
|||||||
from tensorflow.contrib.distributions.python.ops import bijectors
|
from tensorflow.contrib.distributions.python.ops import bijectors
|
||||||
from tensorflow.contrib.distributions.python.ops.binomial import *
|
from tensorflow.contrib.distributions.python.ops.binomial import *
|
||||||
from tensorflow.contrib.distributions.python.ops.chi2 import *
|
from tensorflow.contrib.distributions.python.ops.chi2 import *
|
||||||
|
from tensorflow.contrib.distributions.python.ops.conditional_distribution import *
|
||||||
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
|
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
|
||||||
from tensorflow.contrib.distributions.python.ops.deterministic import *
|
from tensorflow.contrib.distributions.python.ops.deterministic import *
|
||||||
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
|
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
|
||||||
@ -44,12 +45,10 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
|
|||||||
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
|
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
|
||||||
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
|
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
|
||||||
from tensorflow.contrib.distributions.python.ops.sample_stats import *
|
from tensorflow.contrib.distributions.python.ops.sample_stats import *
|
||||||
from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
|
|
||||||
from tensorflow.contrib.distributions.python.ops.wishart import *
|
from tensorflow.contrib.distributions.python.ops.wishart import *
|
||||||
from tensorflow.python.ops.distributions.bernoulli import *
|
from tensorflow.python.ops.distributions.bernoulli import *
|
||||||
from tensorflow.python.ops.distributions.beta import *
|
from tensorflow.python.ops.distributions.beta import *
|
||||||
from tensorflow.python.ops.distributions.categorical import *
|
from tensorflow.python.ops.distributions.categorical import *
|
||||||
from tensorflow.python.ops.distributions.conditional_distribution import *
|
|
||||||
from tensorflow.python.ops.distributions.dirichlet import *
|
from tensorflow.python.ops.distributions.dirichlet import *
|
||||||
from tensorflow.python.ops.distributions.dirichlet_multinomial import *
|
from tensorflow.python.ops.distributions.dirichlet_multinomial import *
|
||||||
from tensorflow.python.ops.distributions.distribution import *
|
from tensorflow.python.ops.distributions.distribution import *
|
||||||
@ -60,6 +59,7 @@ from tensorflow.python.ops.distributions.laplace import *
|
|||||||
from tensorflow.python.ops.distributions.multinomial import *
|
from tensorflow.python.ops.distributions.multinomial import *
|
||||||
from tensorflow.python.ops.distributions.normal import *
|
from tensorflow.python.ops.distributions.normal import *
|
||||||
from tensorflow.python.ops.distributions.student_t import *
|
from tensorflow.python.ops.distributions.student_t import *
|
||||||
|
from tensorflow.python.ops.distributions.transformed_distribution import *
|
||||||
from tensorflow.python.ops.distributions.uniform import *
|
from tensorflow.python.ops.distributions.uniform import *
|
||||||
|
|
||||||
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
|
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
|
||||||
|
@ -23,9 +23,9 @@ import itertools
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
|
from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,12 +20,12 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain
|
from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
|
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
|
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
|
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,11 +19,11 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import bijectors
|
from tensorflow.contrib.distributions.python.ops import bijectors
|
||||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
|
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.distributions import gamma as gamma_lib
|
from tensorflow.python.ops.distributions import gamma as gamma_lib
|
||||||
|
from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,9 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
|
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,11 +19,11 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import bijectors
|
from tensorflow.contrib.distributions.python.ops import bijectors
|
||||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
|
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.distributions import gamma as gamma_lib
|
from tensorflow.python.ops.distributions import gamma as gamma_lib
|
||||||
|
from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,9 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import PowerTransform
|
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import PowerTransform
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,9 +21,9 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import special
|
from scipy import special
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
|
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,9 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
|
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,9 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
|
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
|
||||||
|
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
rng = np.random.RandomState(42)
|
rng = np.random.RandomState(42)
|
||||||
|
@ -43,7 +43,6 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
|
|||||||
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
|
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
|
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
|
from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.identity import *
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
|
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
|
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
|
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
|
||||||
@ -52,6 +51,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered impo
|
|||||||
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
|
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
|
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
|
||||||
from tensorflow.python.ops.distributions.bijector import *
|
from tensorflow.python.ops.distributions.bijector import *
|
||||||
|
from tensorflow.python.ops.distributions.identity_bijector import Identity
|
||||||
|
|
||||||
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
|
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
|
||||||
|
|
||||||
|
@ -1,29 +0,0 @@
|
|||||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Identity bijector."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
# go/tf-wildcard-import
|
|
||||||
# pylint: disable=wildcard-import
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.identity_impl import *
|
|
||||||
# pylint: enable=wildcard-import
|
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
|
||||||
|
|
||||||
_allowed_symbols = ["Identity"]
|
|
||||||
|
|
||||||
remove_undocumented(__name__, _allowed_symbols)
|
|
@ -17,9 +17,9 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
from tensorflow.contrib.distributions.python.ops import conditional_distribution
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.distributions import conditional_distribution
|
from tensorflow.python.ops.distributions import transformed_distribution
|
||||||
from tensorflow.python.ops.distributions import util as distribution_util
|
from tensorflow.python.ops.distributions import util as distribution_util
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.contrib import linalg
|
from tensorflow.contrib import linalg
|
||||||
from tensorflow.contrib.distributions.python.ops import bijectors
|
from tensorflow.contrib.distributions.python.ops import bijectors
|
||||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
@ -29,6 +28,7 @@ from tensorflow.python.ops import linalg_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.distributions import kullback_leibler
|
from tensorflow.python.ops.distributions import kullback_leibler
|
||||||
from tensorflow.python.ops.distributions import normal
|
from tensorflow.python.ops.distributions import normal
|
||||||
|
from tensorflow.python.ops.distributions import transformed_distribution
|
||||||
from tensorflow.python.ops.distributions import util as distribution_util
|
from tensorflow.python.ops.distributions import util as distribution_util
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import logistic
|
from tensorflow.contrib.distributions.python.ops import logistic
|
||||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
|
||||||
# Bijectors must be directly imported because `remove_undocumented` prevents
|
# Bijectors must be directly imported because `remove_undocumented` prevents
|
||||||
# individual file imports.
|
# individual file imports.
|
||||||
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
|
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
|
||||||
@ -27,6 +26,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops.distributions import transformed_distribution
|
||||||
from tensorflow.python.ops.distributions import util as distribution_util
|
from tensorflow.python.ops.distributions import util as distribution_util
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.contrib.distributions.python.ops import bijectors
|
from tensorflow.contrib.distributions.python.ops import bijectors
|
||||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -30,6 +29,7 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops.distributions import distribution
|
from tensorflow.python.ops.distributions import distribution
|
||||||
|
from tensorflow.python.ops.distributions import transformed_distribution
|
||||||
from tensorflow.python.ops.distributions import util as distribution_util
|
from tensorflow.python.ops.distributions import util as distribution_util
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,13 +19,13 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import bijectors
|
from tensorflow.contrib.distributions.python.ops import bijectors
|
||||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.distributions import student_t
|
from tensorflow.python.ops.distributions import student_t
|
||||||
|
from tensorflow.python.ops.distributions import transformed_distribution
|
||||||
from tensorflow.python.ops.distributions import util as distribution_util
|
from tensorflow.python.ops.distributions import util as distribution_util
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,6 +108,7 @@ tf_custom_op_py_library(
|
|||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
|
"//tensorflow/python/feature_column",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -136,8 +136,10 @@ from tensorflow.contrib.layers.python.layers import layers
|
|||||||
from tensorflow.contrib.layers.python.ops import bucketization_op
|
from tensorflow.contrib.layers.python.ops import bucketization_op
|
||||||
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
|
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
|
||||||
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
|
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
|
||||||
|
from tensorflow.python.feature_column import feature_column as fc_core
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
|
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -1497,7 +1499,10 @@ def _real_valued_var_len_column(column_name,
|
|||||||
is_sparse)
|
is_sparse)
|
||||||
|
|
||||||
|
|
||||||
class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
|
class _RealValuedColumn(
|
||||||
|
_FeatureColumn,
|
||||||
|
fc_core._DenseColumn, # pylint: disable=protected-access
|
||||||
|
collections.namedtuple(
|
||||||
"_RealValuedColumn",
|
"_RealValuedColumn",
|
||||||
["column_name", "dimension", "default_value", "dtype", "normalizer"])):
|
["column_name", "dimension", "default_value", "dtype", "normalizer"])):
|
||||||
"""Represents a real valued feature column also known as continuous features.
|
"""Represents a real valued feature column also known as continuous features.
|
||||||
@ -1569,6 +1574,23 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
def _to_dense_tensor(self, input_tensor):
|
def _to_dense_tensor(self, input_tensor):
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _variable_shape(self):
|
||||||
|
return tensor_shape.TensorShape((self.dimension))
|
||||||
|
|
||||||
|
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
|
||||||
|
del weight_collections
|
||||||
|
del trainable
|
||||||
|
return inputs.get(self)
|
||||||
|
|
||||||
|
def _transform_feature(self, inputs):
|
||||||
|
return math_ops.to_float(
|
||||||
|
self._normalized_input_tensor(inputs.get(self.name)))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _parse_example_config(self):
|
||||||
|
return self.config
|
||||||
|
|
||||||
|
|
||||||
def real_valued_column(column_name,
|
def real_valued_column(column_name,
|
||||||
dimension=1,
|
dimension=1,
|
||||||
|
@ -27,14 +27,15 @@ from tensorflow.contrib.layers.python.layers import feature_column
|
|||||||
from tensorflow.contrib.layers.python.layers import feature_column_ops
|
from tensorflow.contrib.layers.python.layers import feature_column_ops
|
||||||
from tensorflow.core.example import example_pb2
|
from tensorflow.core.example import example_pb2
|
||||||
from tensorflow.core.example import feature_pb2
|
from tensorflow.core.example import feature_pb2
|
||||||
|
from tensorflow.python.feature_column import feature_column as fc_core
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import partitioned_variables
|
from tensorflow.python.ops import partitioned_variables
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
@ -223,7 +224,7 @@ class TransformerTest(test.TestCase):
|
|||||||
self.assertEqual(len(output), 1)
|
self.assertEqual(len(output), 1)
|
||||||
self.assertIn(keys_sparse, output)
|
self.assertIn(keys_sparse, output)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
|
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
|
||||||
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
|
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
|
||||||
self.assertAllEqual(output[keys_sparse].indices.eval(),
|
self.assertAllEqual(output[keys_sparse].indices.eval(),
|
||||||
@ -241,7 +242,7 @@ class TransformerTest(test.TestCase):
|
|||||||
output = feature_column_ops._Transformer(features).transform(keys_sparse)
|
output = feature_column_ops._Transformer(features).transform(keys_sparse)
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
# While the input is a dense Tensor, the output should be a SparseTensor.
|
# While the input is a dense Tensor, the output should be a SparseTensor.
|
||||||
self.assertIsInstance(output, sparse_tensor.SparseTensor)
|
self.assertIsInstance(output, sparse_tensor.SparseTensor)
|
||||||
self.assertEqual(output.dtype, dtypes.int64)
|
self.assertEqual(output.dtype, dtypes.int64)
|
||||||
@ -310,7 +311,7 @@ class TransformerTest(test.TestCase):
|
|||||||
self.assertIn(weighted_ids, output)
|
self.assertIn(weighted_ids, output)
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
|
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
|
||||||
ids_tensor.dense_shape.eval())
|
ids_tensor.dense_shape.eval())
|
||||||
self.assertAllEqual(output[weighted_ids][0].indices.eval(),
|
self.assertAllEqual(output[weighted_ids][0].indices.eval(),
|
||||||
@ -340,7 +341,7 @@ class TransformerTest(test.TestCase):
|
|||||||
self.assertEqual(len(output), 1)
|
self.assertEqual(len(output), 1)
|
||||||
self.assertIn(vocab_sparse, output)
|
self.assertIn(vocab_sparse, output)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
||||||
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
|
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
|
||||||
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
||||||
@ -362,7 +363,7 @@ class TransformerTest(test.TestCase):
|
|||||||
self.assertEqual(len(output), 1)
|
self.assertEqual(len(output), 1)
|
||||||
self.assertIn(vocab_sparse, output)
|
self.assertIn(vocab_sparse, output)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
||||||
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
|
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
|
||||||
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
||||||
@ -386,7 +387,7 @@ class TransformerTest(test.TestCase):
|
|||||||
self.assertEqual(len(output), 1)
|
self.assertEqual(len(output), 1)
|
||||||
self.assertIn(vocab_sparse, output)
|
self.assertIn(vocab_sparse, output)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
||||||
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
|
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
|
||||||
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
||||||
@ -408,7 +409,7 @@ class TransformerTest(test.TestCase):
|
|||||||
self.assertEqual(len(output), 1)
|
self.assertEqual(len(output), 1)
|
||||||
self.assertIn(vocab_sparse, output)
|
self.assertIn(vocab_sparse, output)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
||||||
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
|
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
|
||||||
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
||||||
@ -600,7 +601,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
one_hot_column, embedding_column, real_valued_column])
|
one_hot_column, embedding_column, real_valued_column])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
|
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
|
||||||
|
|
||||||
def testRealValuedColumn(self):
|
def testRealValuedColumn(self):
|
||||||
@ -610,6 +611,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[real_valued])
|
[real_valued])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
self.assertAllClose(output.eval(), features["price"].eval())
|
self.assertAllClose(output.eval(), features["price"].eval())
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllClose(output.eval(),
|
||||||
|
fc_core.make_input_layer(features,
|
||||||
|
[real_valued]).eval())
|
||||||
|
|
||||||
def testRealValuedColumnWithMultiDimensions(self):
|
def testRealValuedColumnWithMultiDimensions(self):
|
||||||
real_valued = feature_column.real_valued_column("price", 2)
|
real_valued = feature_column.real_valued_column("price", 2)
|
||||||
@ -620,6 +625,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[real_valued])
|
[real_valued])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
self.assertAllClose(output.eval(), features["price"].eval())
|
self.assertAllClose(output.eval(), features["price"].eval())
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllClose(output.eval(),
|
||||||
|
fc_core.make_input_layer(features,
|
||||||
|
[real_valued]).eval())
|
||||||
|
|
||||||
def testRealValuedColumnSparse(self):
|
def testRealValuedColumnSparse(self):
|
||||||
sparse_real_valued = feature_column._real_valued_var_len_column(
|
sparse_real_valued = feature_column._real_valued_var_len_column(
|
||||||
@ -640,6 +649,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[real_valued])
|
[real_valued])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
self.assertAllClose(output.eval(), features["price"].eval() - 2)
|
self.assertAllClose(output.eval(), features["price"].eval() - 2)
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllClose(output.eval(),
|
||||||
|
fc_core.make_input_layer(features,
|
||||||
|
[real_valued]).eval())
|
||||||
|
|
||||||
def testRealValuedColumnWithMultiDimensionsAndNormalizer(self):
|
def testRealValuedColumnWithMultiDimensionsAndNormalizer(self):
|
||||||
real_valued = feature_column.real_valued_column(
|
real_valued = feature_column.real_valued_column(
|
||||||
@ -651,6 +664,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[real_valued])
|
[real_valued])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
self.assertAllClose(output.eval(), features["price"].eval() - 2)
|
self.assertAllClose(output.eval(), features["price"].eval() - 2)
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllClose(output.eval(),
|
||||||
|
fc_core.make_input_layer(features,
|
||||||
|
[real_valued]).eval())
|
||||||
|
|
||||||
def testBucketizedColumnWithNormalizerSucceedsForDNN(self):
|
def testBucketizedColumnWithNormalizerSucceedsForDNN(self):
|
||||||
bucket = feature_column.bucketized_column(
|
bucket = feature_column.bucketized_column(
|
||||||
@ -697,7 +714,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[one_hot_column])
|
[one_hot_column])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
|
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
|
||||||
output.eval())
|
output.eval())
|
||||||
|
|
||||||
@ -715,7 +732,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
|
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
|
||||||
output.eval())
|
output.eval())
|
||||||
|
|
||||||
@ -733,7 +750,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
||||||
output.eval())
|
output.eval())
|
||||||
|
|
||||||
@ -767,7 +784,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[one_hot_sparse])
|
[one_hot_sparse])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual([3, 10], output.eval().shape)
|
self.assertAllEqual([3, 10], output.eval().shape)
|
||||||
|
|
||||||
def testEmbeddingColumnSucceedsForDNN(self):
|
def testEmbeddingColumnSucceedsForDNN(self):
|
||||||
@ -874,7 +891,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[embeded_sparse])
|
[embeded_sparse])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(output.eval().shape, [2, 10])
|
self.assertAllEqual(output.eval().shape, [2, 10])
|
||||||
|
|
||||||
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
|
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
|
||||||
@ -897,7 +914,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[embeded_sparse])
|
[embeded_sparse])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(output.eval().shape, [2, 10])
|
self.assertAllEqual(output.eval().shape, [2, 10])
|
||||||
|
|
||||||
def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self):
|
def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self):
|
||||||
@ -948,7 +965,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
"Error creating input layer for column: ids_weighted_by_weights"):
|
"Error creating input layer for column: ids_weighted_by_weights"):
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
feature_column_ops.input_from_feature_columns(features, [weighted_ids])
|
feature_column_ops.input_from_feature_columns(features, [weighted_ids])
|
||||||
|
|
||||||
def testCrossedColumnFailsForDNN(self):
|
def testCrossedColumnFailsForDNN(self):
|
||||||
@ -1055,7 +1072,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
[embeded_sparse])
|
[embeded_sparse])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
# score: (sum of weights)
|
# score: (sum of weights)
|
||||||
self.assertAllEqual(output.eval(), [[10.], [50.], [0.]])
|
self.assertAllEqual(output.eval(), [[10.], [50.], [0.]])
|
||||||
|
|
||||||
@ -1293,7 +1310,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
model_input = sess.run(model_input_tensor)
|
model_input = sess.run(model_input_tensor)
|
||||||
|
|
||||||
expected_input_shape = np.array([4, 3, 4])
|
expected_input_shape = np.array([4, 3, 4])
|
||||||
@ -1327,7 +1344,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
model_input = sess.run(model_input_tensor)
|
model_input = sess.run(model_input_tensor)
|
||||||
|
|
||||||
expected_input_shape = np.array([4, 3, hash_buckets])
|
expected_input_shape = np.array([4, 3, hash_buckets])
|
||||||
@ -1357,7 +1374,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
model_input = sess.run(model_input_tensor)
|
model_input = sess.run(model_input_tensor)
|
||||||
|
|
||||||
self.assertAllEqual(expected_input_shape, model_input.shape)
|
self.assertAllEqual(expected_input_shape, model_input.shape)
|
||||||
@ -1386,7 +1403,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
model_input = sess.run(model_input_tensor)
|
model_input = sess.run(model_input_tensor)
|
||||||
|
|
||||||
self.assertAllEqual(expected_input_shape, model_input.shape)
|
self.assertAllEqual(expected_input_shape, model_input.shape)
|
||||||
@ -1416,7 +1433,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
|||||||
embedding_weights)
|
embedding_weights)
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
|
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
|
||||||
|
|
||||||
expected_input_shape = [4, 3, embedding_dimension]
|
expected_input_shape = [4, 3, embedding_dimension]
|
||||||
@ -1483,7 +1500,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
model_input = sess.run(model_input_tensor)
|
model_input = sess.run(model_input_tensor)
|
||||||
|
|
||||||
expected_input_shape = [
|
expected_input_shape = [
|
||||||
@ -1564,7 +1581,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [weighted_ids], num_outputs=5)
|
features, [weighted_ids], num_outputs=5)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||||
|
|
||||||
def testWeightedSparseColumnWithDenseInputTensor(self):
|
def testWeightedSparseColumnWithDenseInputTensor(self):
|
||||||
@ -1580,7 +1597,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||||
|
|
||||||
def testCrossedColumn(self):
|
def testCrossedColumn(self):
|
||||||
@ -1634,7 +1651,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [movies], num_outputs=1))
|
features, [movies], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.initialize_all_variables().run()
|
variables_lib.initialize_all_variables().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[movies][0]
|
weights = column_to_variable[movies][0]
|
||||||
self.assertEqual(weights.get_shape(), (3, 1))
|
self.assertEqual(weights.get_shape(), (3, 1))
|
||||||
@ -1709,7 +1726,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [age, language], num_outputs=1))
|
features, [age, language], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
self.assertAllClose(output.eval(), [[0.], [0.]])
|
self.assertAllClose(output.eval(), [[0.], [0.]])
|
||||||
|
|
||||||
@ -1749,7 +1766,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
self.assertEqual(len(variables), 1)
|
self.assertEqual(len(variables), 1)
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
self.assertAllClose(output.eval(), [[0.], [0.]])
|
self.assertAllClose(output.eval(), [[0.], [0.]])
|
||||||
|
|
||||||
@ -1813,7 +1830,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [weighted_language], num_outputs=1))
|
features, [weighted_language], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
self.assertAllClose(output.eval(), [[0.], [0.]])
|
self.assertAllClose(output.eval(), [[0.], [0.]])
|
||||||
|
|
||||||
@ -1841,7 +1858,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [language], num_outputs=1))
|
features, [language], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
# score: 0.1 + language_weight['hindi'] + language_weight['english']
|
# score: 0.1 + language_weight['hindi'] + language_weight['english']
|
||||||
sess.run(bias.assign([0.1]))
|
sess.run(bias.assign([0.1]))
|
||||||
@ -1864,7 +1881,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [movies], num_outputs=1))
|
features, [movies], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[movies][0]
|
weights = column_to_variable[movies][0]
|
||||||
self.assertEqual(weights.get_shape(), (15, 1))
|
self.assertEqual(weights.get_shape(), (15, 1))
|
||||||
@ -1898,7 +1915,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [country_language], num_outputs=1))
|
features, [country_language], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[country_language][0]
|
weights = column_to_variable[country_language][0]
|
||||||
sess.run(weights.assign(weights + 0.4))
|
sess.run(weights.assign(weights + 0.4))
|
||||||
@ -1922,7 +1939,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [language_language], num_outputs=1))
|
features, [language_language], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[language_language][0]
|
weights = column_to_variable[language_language][0]
|
||||||
sess.run(weights.assign(weights + 0.4))
|
sess.run(weights.assign(weights + 0.4))
|
||||||
@ -1955,7 +1972,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [country_language], num_outputs=1))
|
features, [country_language], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[country_language][0]
|
weights = column_to_variable[country_language][0]
|
||||||
sess.run(weights.assign(weights + 0.4))
|
sess.run(weights.assign(weights + 0.4))
|
||||||
@ -1996,7 +2013,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
scope=scope))
|
scope=scope))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
self.assertEqual(2, len(column_to_variable[country]))
|
self.assertEqual(2, len(column_to_variable[country]))
|
||||||
self.assertEqual(3, len(column_to_variable[language]))
|
self.assertEqual(3, len(column_to_variable[language]))
|
||||||
@ -2033,7 +2050,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [country, age, incomes], num_outputs=1))
|
features, [country, age, incomes], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
incomes_weights = column_to_variable[incomes][0]
|
incomes_weights = column_to_variable[incomes][0]
|
||||||
sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]]))
|
sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]]))
|
||||||
@ -2069,7 +2086,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [country, age, height, incomes], num_outputs=5))
|
features, [country, age, height, incomes], num_outputs=5))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
height_weights = column_to_variable[height][0]
|
height_weights = column_to_variable[height][0]
|
||||||
sess.run(
|
sess.run(
|
||||||
@ -2099,7 +2116,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [bucket], num_outputs=1))
|
features, [bucket], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3],
|
sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3],
|
||||||
[0.4]]))
|
[0.4]]))
|
||||||
@ -2127,7 +2144,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [bucket, country], num_outputs=1))
|
features, [bucket, country], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
# dimension = 2, bucket_size = 4, num_classes = 1
|
# dimension = 2, bucket_size = 4, num_classes = 1
|
||||||
sess.run(column_to_variable[bucket][0].assign(
|
sess.run(column_to_variable[bucket][0].assign(
|
||||||
@ -2156,7 +2173,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [bucket, country], num_outputs=5))
|
features, [bucket, country], num_outputs=5))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
# dimension = 2, bucket_size = 4, num_classes = 5
|
# dimension = 2, bucket_size = 4, num_classes = 5
|
||||||
sess.run(column_to_variable[bucket][0].assign(
|
sess.run(column_to_variable[bucket][0].assign(
|
||||||
@ -2192,7 +2209,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [country_price], num_outputs=1))
|
features, [country_price], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[country_price][0]
|
weights = column_to_variable[country_price][0]
|
||||||
sess.run(weights.assign(weights + 0.4))
|
sess.run(weights.assign(weights + 0.4))
|
||||||
@ -2231,7 +2248,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [country_language_price], num_outputs=1))
|
features, [country_language_price], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[country_language_price][0]
|
weights = column_to_variable[country_language_price][0]
|
||||||
sess.run(weights.assign(weights + 0.4))
|
sess.run(weights.assign(weights + 0.4))
|
||||||
@ -2255,7 +2272,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [product], num_outputs=1))
|
features, [product], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
product_weights = column_to_variable[product][0]
|
product_weights = column_to_variable[product][0]
|
||||||
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
||||||
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
||||||
@ -2270,7 +2287,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [product], num_outputs=1))
|
features, [product], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
product_weights = column_to_variable[product][0]
|
product_weights = column_to_variable[product][0]
|
||||||
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
||||||
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
||||||
@ -2285,7 +2302,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [product], num_outputs=1))
|
features, [product], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
product_weights = column_to_variable[product][0]
|
product_weights = column_to_variable[product][0]
|
||||||
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
||||||
self.assertAllClose(output.eval(), [[0.6], [0.7]])
|
self.assertAllClose(output.eval(), [[0.6], [0.7]])
|
||||||
@ -2306,7 +2323,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [product], num_outputs=1))
|
features, [product], num_outputs=1))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
product_weights = column_to_variable[product][0]
|
product_weights = column_to_variable[product][0]
|
||||||
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
||||||
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
||||||
@ -2318,7 +2335,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [feature_column.real_valued_column("age")], num_outputs=3)
|
features, [feature_column.real_valued_column("age")], num_outputs=3)
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
sess.run(bias.assign([0.1, 0.2, 0.3]))
|
sess.run(bias.assign([0.1, 0.2, 0.3]))
|
||||||
self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3],
|
self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3],
|
||||||
[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]])
|
[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]])
|
||||||
@ -2332,7 +2349,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [column], num_outputs=3))
|
features, [column], num_outputs=3))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
weights = column_to_variable[column][0]
|
weights = column_to_variable[column][0]
|
||||||
self.assertEqual(weights.get_shape(), (1, 3))
|
self.assertEqual(weights.get_shape(), (1, 3))
|
||||||
sess.run(weights.assign([[0.01, 0.03, 0.05]]))
|
sess.run(weights.assign([[0.01, 0.03, 0.05]]))
|
||||||
@ -2356,7 +2373,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [column], num_outputs=3))
|
features, [column], num_outputs=3))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
weights = column_to_variable[column][0]
|
weights = column_to_variable[column][0]
|
||||||
self.assertEqual(weights.get_shape(), (5, 3))
|
self.assertEqual(weights.get_shape(), (5, 3))
|
||||||
sess.run(
|
sess.run(
|
||||||
@ -2382,7 +2399,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [column], num_outputs=3))
|
features, [column], num_outputs=3))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[column][0]
|
weights = column_to_variable[column][0]
|
||||||
self.assertEqual(weights.get_shape(), (5, 3))
|
self.assertEqual(weights.get_shape(), (5, 3))
|
||||||
@ -2422,7 +2439,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [column], num_outputs=3))
|
features, [column], num_outputs=3))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[column][0]
|
weights = column_to_variable[column][0]
|
||||||
self.assertEqual(weights.get_shape(), (5, 3))
|
self.assertEqual(weights.get_shape(), (5, 3))
|
||||||
@ -2451,7 +2468,7 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features, [column], num_outputs=3))
|
features, [column], num_outputs=3))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
weights = column_to_variable[column][0]
|
weights = column_to_variable[column][0]
|
||||||
self.assertEqual(weights.get_shape(), (5, 3))
|
self.assertEqual(weights.get_shape(), (5, 3))
|
||||||
@ -2516,7 +2533,7 @@ class ParseExampleTest(test.TestCase):
|
|||||||
self.assertIn(bucket, output)
|
self.assertIn(bucket, output)
|
||||||
self.assertIn(wire_cast, output)
|
self.assertIn(wire_cast, output)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
|
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
|
||||||
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
|
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
|
||||||
self.assertAllEqual(output[wire_cast].values.eval(), [2, 0])
|
self.assertAllEqual(output[wire_cast].values.eval(), [2, 0])
|
||||||
|
@ -46,7 +46,7 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
|
|||||||
Args:
|
Args:
|
||||||
uniform: Whether to use uniform or normal distributed random initialization.
|
uniform: Whether to use uniform or normal distributed random initialization.
|
||||||
seed: A Python integer. Used to create random seeds. See
|
seed: A Python integer. Used to create random seeds. See
|
||||||
@{set_random_seed} for behavior.
|
@{tf.set_random_seed} for behavior.
|
||||||
dtype: The data type. Only floating point types are supported.
|
dtype: The data type. Only floating point types are supported.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -96,7 +96,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
|
|||||||
mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'.
|
mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'.
|
||||||
uniform: Whether to use uniform or normal distributed random initialization.
|
uniform: Whether to use uniform or normal distributed random initialization.
|
||||||
seed: A Python integer. Used to create random seeds. See
|
seed: A Python integer. Used to create random seeds. See
|
||||||
@{set_random_seed} for behavior.
|
@{tf.set_random_seed} for behavior.
|
||||||
dtype: The data type. Only floating point types are supported.
|
dtype: The data type. Only floating point types are supported.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
|
||||||
from tensorflow.python.ops import functional_ops
|
from tensorflow.python.ops import functional_ops
|
||||||
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
@ -157,7 +157,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
|
|||||||
self.context_feature_columns)
|
self.context_feature_columns)
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
sess.run(variables.global_variables_initializer())
|
||||||
sess.run(data_flow_ops.tables_initializer())
|
sess.run(lookup_ops.tables_initializer())
|
||||||
sequence_input_val = sess.run(sequence_input)
|
sequence_input_val = sess.run(sequence_input)
|
||||||
expected_shape = np.array([
|
expected_shape = np.array([
|
||||||
3, # expected batch size
|
3, # expected batch size
|
||||||
@ -178,7 +178,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
|
|||||||
# Obtain values of activations and final state.
|
# Obtain values of activations and final state.
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
sess.run(variables.global_variables_initializer())
|
||||||
sess.run(data_flow_ops.tables_initializer())
|
sess.run(lookup_ops.tables_initializer())
|
||||||
activations, final_state = sess.run([activations_t, final_state_t])
|
activations, final_state = sess.run([activations_t, final_state_t])
|
||||||
|
|
||||||
expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
|
expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
|
||||||
|
@ -57,7 +57,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import resources
|
from tensorflow.python.ops import resources
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
@ -1292,7 +1292,7 @@ class Estimator(BaseEstimator):
|
|||||||
init_op = control_flow_ops.group(
|
init_op = control_flow_ops.group(
|
||||||
variables.local_variables_initializer(),
|
variables.local_variables_initializer(),
|
||||||
resources.initialize_resources(resources.shared_resources()),
|
resources.initialize_resources(resources.shared_resources()),
|
||||||
data_flow_ops.tables_initializer())
|
lookup_ops.tables_initializer())
|
||||||
|
|
||||||
# Perform the export
|
# Perform the export
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||||
|
@ -32,7 +32,7 @@ from tensorflow.core.framework import summary_pb2
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops.losses import losses as losses_lib
|
from tensorflow.python.ops.losses import losses as losses_lib
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -1214,7 +1214,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
train_op_fn=head_lib.no_op_train_fn,
|
train_op_fn=head_lib.no_op_train_fn,
|
||||||
logits=((1., 0., 0.), (0., 0., 1.),))
|
logits=((1., 0., 0.), (0., 0., 1.),))
|
||||||
with session.Session():
|
with session.Session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
[0, 2],
|
[0, 2],
|
||||||
model_fn_ops.predictions["classes"].eval())
|
model_fn_ops.predictions["classes"].eval())
|
||||||
@ -1266,7 +1266,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
train_op_fn=head_lib.no_op_train_fn,
|
train_op_fn=head_lib.no_op_train_fn,
|
||||||
logits=((1., 0., 0.), (0., 0., 1.),))
|
logits=((1., 0., 0.), (0., 0., 1.),))
|
||||||
with session.Session():
|
with session.Session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
[b"key0", b"key2"],
|
[b"key0", b"key2"],
|
||||||
model_fn_ops.predictions["classes"].eval())
|
model_fn_ops.predictions["classes"].eval())
|
||||||
@ -1301,7 +1301,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
train_op_fn=head_lib.no_op_train_fn,
|
train_op_fn=head_lib.no_op_train_fn,
|
||||||
logits=((1., 0., 0.),))
|
logits=((1., 0., 0.),))
|
||||||
with session.Session():
|
with session.Session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertIsNone(model_fn_ops.train_op)
|
self.assertIsNone(model_fn_ops.train_op)
|
||||||
_assert_no_variables(self)
|
_assert_no_variables(self)
|
||||||
_assert_summary_tags(self, ["loss"])
|
_assert_summary_tags(self, ["loss"])
|
||||||
@ -1327,7 +1327,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
train_op_fn=head_lib.no_op_train_fn,
|
train_op_fn=head_lib.no_op_train_fn,
|
||||||
logits=((0., 0., 1.),))
|
logits=((0., 0., 1.),))
|
||||||
with session.Session():
|
with session.Session():
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertIsNone(model_fn_ops.train_op)
|
self.assertIsNone(model_fn_ops.train_op)
|
||||||
_assert_no_variables(self)
|
_assert_no_variables(self)
|
||||||
_assert_summary_tags(self, ["loss"])
|
_assert_summary_tags(self, ["loss"])
|
||||||
|
@ -35,8 +35,8 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
@ -55,7 +55,7 @@ class PrepareInputsForRnnTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
sess.run(variables.global_variables_initializer())
|
||||||
sess.run(data_flow_ops.initialize_all_tables())
|
sess.run(lookup_ops.tables_initializer())
|
||||||
features_val = sess.run(features_by_time)
|
features_val = sess.run(features_by_time)
|
||||||
self.assertAllEqual(expected, features_val)
|
self.assertAllEqual(expected, features_val)
|
||||||
|
|
||||||
@ -316,7 +316,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
sess.run(variables.global_variables_initializer())
|
||||||
sess.run(data_flow_ops.initialize_all_tables())
|
sess.run(lookup_ops.tables_initializer())
|
||||||
actual_sequence, actual_context = sess.run(
|
actual_sequence, actual_context = sess.run(
|
||||||
[sequence, context])
|
[sequence, context])
|
||||||
assert_equal(expected_sequence, actual_sequence)
|
assert_equal(expected_sequence, actual_sequence)
|
||||||
|
@ -647,6 +647,10 @@ class Experiment(object):
|
|||||||
if _sentinel is not None:
|
if _sentinel is not None:
|
||||||
raise ValueError("_call_train should be called with keyword args only")
|
raise ValueError("_call_train should be called with keyword args only")
|
||||||
|
|
||||||
|
# Estimator in core cannot work with monitors. We need to convert them
|
||||||
|
# to hooks. For Estimator in contrib, it is converted internally. So, it is
|
||||||
|
# safe to convert for both cases.
|
||||||
|
hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator)
|
||||||
if self._core_estimator_used:
|
if self._core_estimator_used:
|
||||||
return self._estimator.train(input_fn=input_fn,
|
return self._estimator.train(input_fn=input_fn,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
|
@ -24,7 +24,6 @@ import time
|
|||||||
|
|
||||||
from tensorflow.contrib.learn.python.learn import evaluable
|
from tensorflow.contrib.learn.python.learn import evaluable
|
||||||
from tensorflow.contrib.learn.python.learn import experiment
|
from tensorflow.contrib.learn.python.learn import experiment
|
||||||
from tensorflow.contrib.learn.python.learn import monitors
|
|
||||||
from tensorflow.contrib.learn.python.learn import run_config
|
from tensorflow.contrib.learn.python.learn import run_config
|
||||||
from tensorflow.contrib.learn.python.learn import trainable
|
from tensorflow.contrib.learn.python.learn import trainable
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
||||||
@ -461,7 +460,8 @@ class ExperimentTest(test.TestCase):
|
|||||||
self.assertEqual(1, est.eval_count)
|
self.assertEqual(1, est.eval_count)
|
||||||
self.assertEqual(1, len(est.monitors))
|
self.assertEqual(1, len(est.monitors))
|
||||||
self.assertEqual([noop_hook], est.eval_hooks)
|
self.assertEqual([noop_hook], est.eval_hooks)
|
||||||
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))
|
self.assertTrue(isinstance(est.monitors[0],
|
||||||
|
session_run_hook.SessionRunHook))
|
||||||
|
|
||||||
def test_train_hooks_extend_does_not_mutate_input_hooks(self):
|
def test_train_hooks_extend_does_not_mutate_input_hooks(self):
|
||||||
for est in self._estimators_for_tests():
|
for est in self._estimators_for_tests():
|
||||||
@ -563,7 +563,8 @@ class ExperimentTest(test.TestCase):
|
|||||||
self.assertEqual(1, est.export_count)
|
self.assertEqual(1, est.export_count)
|
||||||
self.assertEqual(1, len(est.monitors))
|
self.assertEqual(1, len(est.monitors))
|
||||||
self.assertEqual([noop_hook], est.eval_hooks)
|
self.assertEqual([noop_hook], est.eval_hooks)
|
||||||
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))
|
self.assertTrue(isinstance(est.monitors[0],
|
||||||
|
session_run_hook.SessionRunHook))
|
||||||
|
|
||||||
def test_train_and_evaluate_with_no_eval_during_training(self):
|
def test_train_and_evaluate_with_no_eval_during_training(self):
|
||||||
for est in self._estimators_for_tests():
|
for est in self._estimators_for_tests():
|
||||||
|
@ -37,8 +37,8 @@ from tensorflow.python.client import session as tf_session
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
|
||||||
from tensorflow.python.ops import logging_ops
|
from tensorflow.python.ops import logging_ops
|
||||||
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import resources
|
from tensorflow.python.ops import resources
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
@ -429,11 +429,14 @@ def _get_ready_op():
|
|||||||
|
|
||||||
|
|
||||||
def _get_local_init_op():
|
def _get_local_init_op():
|
||||||
|
"""Returns the local init ops to initialize tables and local variables."""
|
||||||
local_init_op = _get_first_op_from_collection(
|
local_init_op = _get_first_op_from_collection(
|
||||||
ops.GraphKeys.LOCAL_INIT_OP)
|
ops.GraphKeys.LOCAL_INIT_OP)
|
||||||
if local_init_op is None:
|
if local_init_op is None:
|
||||||
op_list = [variables.local_variables_initializer(),
|
op_list = [
|
||||||
data_flow_ops.tables_initializer()]
|
variables.local_variables_initializer(),
|
||||||
|
lookup_ops.tables_initializer()
|
||||||
|
]
|
||||||
if op_list:
|
if op_list:
|
||||||
local_init_op = control_flow_ops.group(*op_list)
|
local_init_op = control_flow_ops.group(*op_list)
|
||||||
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
|
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
|
||||||
@ -680,7 +683,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
|
|||||||
else:
|
else:
|
||||||
session.run(variables.global_variables_initializer())
|
session.run(variables.global_variables_initializer())
|
||||||
session.run(variables.local_variables_initializer())
|
session.run(variables.local_variables_initializer())
|
||||||
session.run(data_flow_ops.tables_initializer())
|
session.run(lookup_ops.tables_initializer())
|
||||||
coord = coordinator.Coordinator()
|
coord = coordinator.Coordinator()
|
||||||
threads = None
|
threads = None
|
||||||
try:
|
try:
|
||||||
|
@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import saver as tf_saver
|
from tensorflow.python.training import saver as tf_saver
|
||||||
@ -67,17 +67,17 @@ def _export_graph(graph, saver, checkpoint_path, export_dir,
|
|||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
with tf_session.Session('') as session:
|
with tf_session.Session('') as session:
|
||||||
variables.local_variables_initializer()
|
variables.local_variables_initializer()
|
||||||
data_flow_ops.tables_initializer()
|
lookup_ops.tables_initializer()
|
||||||
saver.restore(session, checkpoint_path)
|
saver.restore(session, checkpoint_path)
|
||||||
|
|
||||||
export = exporter.Exporter(saver)
|
export = exporter.Exporter(saver)
|
||||||
export.init(init_op=control_flow_ops.group(
|
export.init(
|
||||||
|
init_op=control_flow_ops.group(
|
||||||
variables.local_variables_initializer(),
|
variables.local_variables_initializer(),
|
||||||
data_flow_ops.tables_initializer()),
|
lookup_ops.tables_initializer()),
|
||||||
default_graph_signature=default_graph_signature,
|
default_graph_signature=default_graph_signature,
|
||||||
named_graph_signatures=named_graph_signatures,
|
named_graph_signatures=named_graph_signatures,
|
||||||
assets_collection=ops.get_collection(
|
assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))
|
||||||
ops.GraphKeys.ASSET_FILEPATHS))
|
|
||||||
return export.export(export_dir, contrib_variables.get_global_step(),
|
return export.export(export_dir, contrib_variables.get_global_step(),
|
||||||
session, exports_to_keep=exports_to_keep)
|
session, exports_to_keep=exports_to_keep)
|
||||||
|
|
||||||
|
@ -13,19 +13,10 @@ py_library(
|
|||||||
name = "lookup_py",
|
name = "lookup_py",
|
||||||
srcs = [
|
srcs = [
|
||||||
"__init__.py",
|
"__init__.py",
|
||||||
"lookup_ops.py",
|
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python/feature_column:lookup_ops",
|
||||||
"//tensorflow/python:control_flow_ops",
|
|
||||||
"//tensorflow/python:data_flow_ops_gen",
|
|
||||||
"//tensorflow/python:framework",
|
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
|
||||||
"//tensorflow/python:math_ops",
|
|
||||||
"//tensorflow/python:string_ops",
|
|
||||||
"//tensorflow/python:training",
|
|
||||||
"//tensorflow/python:util",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,11 +30,11 @@ py_test(
|
|||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client",
|
"//tensorflow/python:client",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:data_flow_ops",
|
|
||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:lookup_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
|
@ -47,7 +47,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=unused-import,wildcard-import
|
# pylint: disable=unused-import,wildcard-import
|
||||||
from tensorflow.contrib.lookup.lookup_ops import *
|
from tensorflow.python.feature_column.lookup_ops import *
|
||||||
# pylint: enable=unused-import,wildcard-import
|
# pylint: enable=unused-import,wildcard-import
|
||||||
|
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
@ -31,7 +31,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training import saver
|
||||||
@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase):
|
|||||||
table3 = lookup.HashTable(
|
table3 = lookup.HashTable(
|
||||||
lookup.KeyValueTensorInitializer(keys, values), default_val)
|
lookup.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
|
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(3, table1.size().eval())
|
self.assertAllEqual(3, table1.size().eval())
|
||||||
self.assertAllEqual(3, table2.size().eval())
|
self.assertAllEqual(3, table2.size().eval())
|
||||||
self.assertAllEqual(3, table3.size().eval())
|
self.assertAllEqual(3, table3.size().eval())
|
||||||
@ -1184,7 +1184,7 @@ class IndexTableFromFile(test.TestCase):
|
|||||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||||
|
|
||||||
def test_int32_index_table_from_file(self):
|
def test_int32_index_table_from_file(self):
|
||||||
@ -1198,7 +1198,7 @@ class IndexTableFromFile(test.TestCase):
|
|||||||
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
|
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||||
|
|
||||||
def test_int64_index_table_from_file(self):
|
def test_int64_index_table_from_file(self):
|
||||||
@ -1212,7 +1212,7 @@ class IndexTableFromFile(test.TestCase):
|
|||||||
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
|
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||||
|
|
||||||
def test_index_table_from_file_with_default_value(self):
|
def test_index_table_from_file_with_default_value(self):
|
||||||
@ -1224,7 +1224,7 @@ class IndexTableFromFile(test.TestCase):
|
|||||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, default_value), ids.eval())
|
self.assertAllEqual((1, 2, default_value), ids.eval())
|
||||||
|
|
||||||
def test_index_table_from_file_with_oov_buckets(self):
|
def test_index_table_from_file_with_oov_buckets(self):
|
||||||
@ -1236,7 +1236,7 @@ class IndexTableFromFile(test.TestCase):
|
|||||||
constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
|
constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
(
|
(
|
||||||
1, # From vocabulary file.
|
1, # From vocabulary file.
|
||||||
@ -1259,7 +1259,7 @@ class IndexTableFromFile(test.TestCase):
|
|||||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, -1, -1), ids.eval())
|
self.assertAllEqual((1, -1, -1), ids.eval())
|
||||||
self.assertEqual(2, table.size().eval())
|
self.assertEqual(2, table.size().eval())
|
||||||
|
|
||||||
@ -1286,7 +1286,7 @@ class IndexTableFromFile(test.TestCase):
|
|||||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, -1), ids.eval())
|
self.assertAllEqual((1, 2, -1), ids.eval())
|
||||||
self.assertEqual(3, table.size().eval())
|
self.assertEqual(3, table.size().eval())
|
||||||
|
|
||||||
@ -1345,7 +1345,7 @@ class IndexTableFromTensor(test.TestCase):
|
|||||||
ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
|
ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||||
|
|
||||||
def test_int32_index_table_from_tensor_with_tensor_init(self):
|
def test_int32_index_table_from_tensor_with_tensor_init(self):
|
||||||
@ -1356,7 +1356,7 @@ class IndexTableFromTensor(test.TestCase):
|
|||||||
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
|
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||||
|
|
||||||
def test_int64_index_table_from_tensor_with_tensor_init(self):
|
def test_int64_index_table_from_tensor_with_tensor_init(self):
|
||||||
@ -1367,7 +1367,7 @@ class IndexTableFromTensor(test.TestCase):
|
|||||||
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
|
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||||
|
|
||||||
def test_index_table_from_tensor_with_default_value(self):
|
def test_index_table_from_tensor_with_default_value(self):
|
||||||
@ -1378,7 +1378,7 @@ class IndexTableFromTensor(test.TestCase):
|
|||||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, default_value), ids.eval())
|
self.assertAllEqual((1, 2, default_value), ids.eval())
|
||||||
|
|
||||||
def test_index_table_from_tensor_missing_mapping(self):
|
def test_index_table_from_tensor_missing_mapping(self):
|
||||||
@ -1394,7 +1394,7 @@ class IndexTableFromTensor(test.TestCase):
|
|||||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
errors_impl.OpError, "keys and values cannot be empty"):
|
errors_impl.OpError, "keys and values cannot be empty"):
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
def test_index_table_from_tensor_with_invalid_hashers(self):
|
def test_index_table_from_tensor_with_invalid_hashers(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -1422,7 +1422,7 @@ class StringToIndexTest(test.TestCase):
|
|||||||
indices = lookup.string_to_index(feats, mapping=mapping_strings)
|
indices = lookup.string_to_index(feats, mapping=mapping_strings)
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, indices.eval)
|
self.assertRaises(errors_impl.OpError, indices.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
self.assertAllEqual((1, 2, -1), indices.eval())
|
self.assertAllEqual((1, 2, -1), indices.eval())
|
||||||
|
|
||||||
@ -1433,7 +1433,7 @@ class StringToIndexTest(test.TestCase):
|
|||||||
_ = lookup.string_to_index(feats, mapping=mapping_strings)
|
_ = lookup.string_to_index(feats, mapping=mapping_strings)
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError,
|
self.assertRaises(errors_impl.OpError,
|
||||||
data_flow_ops.tables_initializer().run)
|
lookup_ops.tables_initializer().run)
|
||||||
|
|
||||||
def test_string_to_index_with_default_value(self):
|
def test_string_to_index_with_default_value(self):
|
||||||
default_value = -42
|
default_value = -42
|
||||||
@ -1444,7 +1444,7 @@ class StringToIndexTest(test.TestCase):
|
|||||||
feats, mapping=mapping_strings, default_value=default_value)
|
feats, mapping=mapping_strings, default_value=default_value)
|
||||||
self.assertRaises(errors_impl.OpError, indices.eval)
|
self.assertRaises(errors_impl.OpError, indices.eval)
|
||||||
|
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((1, 2, default_value), indices.eval())
|
self.assertAllEqual((1, 2, default_value), indices.eval())
|
||||||
|
|
||||||
|
|
||||||
@ -1463,7 +1463,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
|||||||
vocabulary_file=vocabulary_file)
|
vocabulary_file=vocabulary_file)
|
||||||
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
|
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
|
||||||
self.assertRaises(errors_impl.OpError, features.eval)
|
self.assertRaises(errors_impl.OpError, features.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
||||||
features.eval())
|
features.eval())
|
||||||
|
|
||||||
@ -1475,7 +1475,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
|||||||
vocabulary_file=vocabulary_file, default_value=default_value)
|
vocabulary_file=vocabulary_file, default_value=default_value)
|
||||||
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
||||||
self.assertRaises(errors_impl.OpError, features.eval)
|
self.assertRaises(errors_impl.OpError, features.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((b"salad", b"surgery", default_value),
|
self.assertAllEqual((b"salad", b"surgery", default_value),
|
||||||
features.eval())
|
features.eval())
|
||||||
|
|
||||||
@ -1489,7 +1489,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
|||||||
default_value=default_value)
|
default_value=default_value)
|
||||||
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
||||||
self.assertRaises(errors_impl.OpError, features.eval)
|
self.assertRaises(errors_impl.OpError, features.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((b"salad", default_value, default_value),
|
self.assertAllEqual((b"salad", default_value, default_value),
|
||||||
features.eval())
|
features.eval())
|
||||||
|
|
||||||
@ -1501,7 +1501,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
|||||||
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, features.eval)
|
self.assertRaises(errors_impl.OpError, features.eval)
|
||||||
init = data_flow_ops.tables_initializer()
|
init = lookup_ops.tables_initializer()
|
||||||
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||||
"Invalid vocab_size", init.run)
|
"Invalid vocab_size", init.run)
|
||||||
|
|
||||||
@ -1513,7 +1513,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
|||||||
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, features.eval)
|
self.assertRaises(errors_impl.OpError, features.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
|
self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
|
||||||
|
|
||||||
|
|
||||||
@ -1528,7 +1528,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
|
|||||||
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
||||||
features = table.lookup(indices)
|
features = table.lookup(indices)
|
||||||
self.assertRaises(errors_impl.OpError, features.eval)
|
self.assertRaises(errors_impl.OpError, features.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
||||||
features.eval())
|
features.eval())
|
||||||
@ -1540,7 +1540,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
|
|||||||
mapping=mapping_strings)
|
mapping=mapping_strings)
|
||||||
indices = constant_op.constant([0, 1, 4], dtypes.int64)
|
indices = constant_op.constant([0, 1, 4], dtypes.int64)
|
||||||
features = table.lookup(indices)
|
features = table.lookup(indices)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())
|
self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())
|
||||||
|
|
||||||
def test_index_to_string_with_default_value(self):
|
def test_index_to_string_with_default_value(self):
|
||||||
@ -1553,7 +1553,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
|
|||||||
features = table.lookup(indices)
|
features = table.lookup(indices)
|
||||||
self.assertRaises(errors_impl.OpError, features.eval)
|
self.assertRaises(errors_impl.OpError, features.eval)
|
||||||
|
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((b"salad", b"surgery", default_value),
|
self.assertAllEqual((b"salad", b"surgery", default_value),
|
||||||
features.eval())
|
features.eval())
|
||||||
|
|
||||||
@ -1567,7 +1567,7 @@ class IndexToStringTest(test.TestCase):
|
|||||||
feats = lookup.index_to_string(indices, mapping=mapping_strings)
|
feats = lookup.index_to_string(indices, mapping=mapping_strings)
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError, feats.eval)
|
self.assertRaises(errors_impl.OpError, feats.eval)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
||||||
feats.eval())
|
feats.eval())
|
||||||
@ -1577,11 +1577,11 @@ class IndexToStringTest(test.TestCase):
|
|||||||
mapping_strings = constant_op.constant(["hello", "hello"])
|
mapping_strings = constant_op.constant(["hello", "hello"])
|
||||||
indices = constant_op.constant([0, 1, 4], dtypes.int64)
|
indices = constant_op.constant([0, 1, 4], dtypes.int64)
|
||||||
feats = lookup.index_to_string(indices, mapping=mapping_strings)
|
feats = lookup.index_to_string(indices, mapping=mapping_strings)
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())
|
self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())
|
||||||
|
|
||||||
self.assertRaises(errors_impl.OpError,
|
self.assertRaises(errors_impl.OpError,
|
||||||
data_flow_ops.tables_initializer().run)
|
lookup_ops.tables_initializer().run)
|
||||||
|
|
||||||
def test_index_to_string_with_default_value(self):
|
def test_index_to_string_with_default_value(self):
|
||||||
default_value = b"NONE"
|
default_value = b"NONE"
|
||||||
@ -1592,7 +1592,7 @@ class IndexToStringTest(test.TestCase):
|
|||||||
indices, mapping=mapping_strings, default_value=default_value)
|
indices, mapping=mapping_strings, default_value=default_value)
|
||||||
self.assertRaises(errors_impl.OpError, feats.eval)
|
self.assertRaises(errors_impl.OpError, feats.eval)
|
||||||
|
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
|
self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
|
||||||
|
|
||||||
|
|
||||||
@ -1755,7 +1755,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
default_value,
|
default_value,
|
||||||
shared_name=shared_name)
|
shared_name=shared_name)
|
||||||
|
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
input_string = constant_op.constant(["brain", "salad", "tank"])
|
input_string = constant_op.constant(["brain", "salad", "tank"])
|
||||||
|
|
||||||
@ -2081,7 +2081,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
|
|||||||
hasher_spec=lookup.StrongHashSpec((1, 2)),
|
hasher_spec=lookup.StrongHashSpec((1, 2)),
|
||||||
name="table2")
|
name="table2")
|
||||||
|
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
input_string = constant_op.constant(
|
input_string = constant_op.constant(
|
||||||
["fruit", "brain", "salad", "surgery", "UNK"])
|
["fruit", "brain", "salad", "surgery", "UNK"])
|
||||||
@ -2167,7 +2167,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
|
|||||||
default_value2),
|
default_value2),
|
||||||
oov_buckets)
|
oov_buckets)
|
||||||
|
|
||||||
data_flow_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
|
|
||||||
input_string_1 = constant_op.constant(
|
input_string_1 = constant_op.constant(
|
||||||
["brain", "salad", "surgery", "UNK"])
|
["brain", "salad", "surgery", "UNK"])
|
||||||
|
@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc
|
|||||||
tensorflow/core/protobuf/queue_runner.pb.cc
|
tensorflow/core/protobuf/queue_runner.pb.cc
|
||||||
tensorflow/core/protobuf/named_tensor.pb.cc
|
tensorflow/core/protobuf/named_tensor.pb.cc
|
||||||
tensorflow/core/protobuf/meta_graph.pb.cc
|
tensorflow/core/protobuf/meta_graph.pb.cc
|
||||||
|
tensorflow/core/protobuf/cluster.pb.cc
|
||||||
tensorflow/core/protobuf/config.pb.cc
|
tensorflow/core/protobuf/config.pb.cc
|
||||||
tensorflow/core/protobuf/rewriter_config.pb.cc
|
tensorflow/core/protobuf/rewriter_config.pb.cc
|
||||||
tensorflow/core/protobuf/debug.pb.cc
|
tensorflow/core/protobuf/debug.pb.cc
|
||||||
|
@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h
|
|||||||
tensorflow/core/protobuf/queue_runner.pb.h
|
tensorflow/core/protobuf/queue_runner.pb.h
|
||||||
tensorflow/core/protobuf/named_tensor.pb.h
|
tensorflow/core/protobuf/named_tensor.pb.h
|
||||||
tensorflow/core/protobuf/meta_graph.pb.h
|
tensorflow/core/protobuf/meta_graph.pb.h
|
||||||
|
tensorflow/core/protobuf/cluster.pb.h
|
||||||
tensorflow/core/protobuf/config.pb.h
|
tensorflow/core/protobuf/config.pb.h
|
||||||
tensorflow/core/protobuf/debug.pb.h
|
tensorflow/core/protobuf/debug.pb.h
|
||||||
tensorflow/core/protobuf/rewriter_config.pb.h
|
tensorflow/core/protobuf/rewriter_config.pb.h
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
tensorflow/core/util/saved_tensor_slice.pb_text.cc
|
tensorflow/core/util/saved_tensor_slice.pb_text.cc
|
||||||
tensorflow/core/util/memmapped_file_system.pb_text.cc
|
tensorflow/core/util/memmapped_file_system.pb_text.cc
|
||||||
tensorflow/core/protobuf/saver.pb_text.cc
|
tensorflow/core/protobuf/saver.pb_text.cc
|
||||||
|
tensorflow/core/protobuf/cluster.pb_text.cc
|
||||||
tensorflow/core/protobuf/config.pb_text.cc
|
tensorflow/core/protobuf/config.pb_text.cc
|
||||||
tensorflow/core/protobuf/debug.pb_text.cc
|
tensorflow/core/protobuf/debug.pb_text.cc
|
||||||
tensorflow/core/protobuf/rewriter_config.pb_text.cc
|
tensorflow/core/protobuf/rewriter_config.pb_text.cc
|
||||||
|
@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto
|
|||||||
tensorflow/core/protobuf/queue_runner.proto
|
tensorflow/core/protobuf/queue_runner.proto
|
||||||
tensorflow/core/protobuf/named_tensor.proto
|
tensorflow/core/protobuf/named_tensor.proto
|
||||||
tensorflow/core/protobuf/meta_graph.proto
|
tensorflow/core/protobuf/meta_graph.proto
|
||||||
|
tensorflow/core/protobuf/cluster.proto
|
||||||
tensorflow/core/protobuf/config.proto
|
tensorflow/core/protobuf/config.proto
|
||||||
tensorflow/core/protobuf/debug.proto
|
tensorflow/core/protobuf/debug.proto
|
||||||
tensorflow/core/protobuf/rewriter_config.proto
|
tensorflow/core/protobuf/rewriter_config.proto
|
||||||
|
@ -1338,6 +1338,87 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
|
|||||||
name=name_scope)
|
name=name_scope)
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_recall_at_top_k(labels,
|
||||||
|
top_k_predictions,
|
||||||
|
class_id=None,
|
||||||
|
weights=None,
|
||||||
|
metrics_collections=None,
|
||||||
|
updates_collections=None,
|
||||||
|
name=None):
|
||||||
|
"""Computes recall@k of top-k predictions with respect to sparse labels.
|
||||||
|
|
||||||
|
If `class_id` is specified, we calculate recall by considering only the
|
||||||
|
entries in the batch for which `class_id` is in the label, and computing
|
||||||
|
the fraction of them for which `class_id` is in the top-k `predictions`.
|
||||||
|
If `class_id` is not specified, we'll calculate recall as how often on
|
||||||
|
average a class among the labels of a batch entry is in the top-k
|
||||||
|
`predictions`.
|
||||||
|
|
||||||
|
`sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
|
||||||
|
and `false_negative_at_<k>`, that are used to compute the recall_at_k
|
||||||
|
frequency. This frequency is ultimately returned as `recall_at_<k>`: an
|
||||||
|
idempotent operation that simply divides `true_positive_at_<k>` by total
|
||||||
|
(`true_positive_at_<k>` + `false_negative_at_<k>`).
|
||||||
|
|
||||||
|
For estimation of the metric over a stream of data, the function creates an
|
||||||
|
`update_op` operation that updates these variables and returns the
|
||||||
|
`recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
|
||||||
|
true positives and false negatives weighted by `weights`. Then `update_op`
|
||||||
|
increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
|
||||||
|
values.
|
||||||
|
|
||||||
|
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
labels: `int64` `Tensor` or `SparseTensor` with shape
|
||||||
|
[D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
|
||||||
|
target classes for the associated prediction. Commonly, N=1 and `labels`
|
||||||
|
has shape [batch_size, num_labels]. [D1, ... DN] must match
|
||||||
|
`top_k_predictions`. Values should be in range [0, num_classes), where
|
||||||
|
num_classes is the last dimension of `predictions`. Values outside this
|
||||||
|
range always count towards `false_negative_at_<k>`.
|
||||||
|
top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
|
||||||
|
N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
|
||||||
|
The final dimension contains the indices of top-k labels. [D1, ... DN]
|
||||||
|
must match `labels`.
|
||||||
|
class_id: Integer class ID for which we want binary metrics. This should be
|
||||||
|
in range [0, num_classes), where num_classes is the last dimension of
|
||||||
|
`predictions`. If class_id is outside this range, the method returns NAN.
|
||||||
|
weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
|
||||||
|
`labels`. If the latter, it must be broadcastable to `labels` (i.e., all
|
||||||
|
dimensions must be either `1`, or the same as the corresponding `labels`
|
||||||
|
dimension).
|
||||||
|
metrics_collections: An optional list of collections that values should
|
||||||
|
be added to.
|
||||||
|
updates_collections: An optional list of collections that updates should
|
||||||
|
be added to.
|
||||||
|
name: Name of new update operation, and namespace for other dependent ops.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
|
||||||
|
by the sum of `true_positives` and `false_negatives`.
|
||||||
|
update_op: `Operation` that increments `true_positives` and
|
||||||
|
`false_negatives` variables appropriately, and whose value matches
|
||||||
|
`recall`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `weights` is not `None` and its shape doesn't match
|
||||||
|
`predictions`, or if either `metrics_collections` or `updates_collections`
|
||||||
|
are not a list or tuple.
|
||||||
|
"""
|
||||||
|
default_name = _at_k_name('recall', class_id=class_id)
|
||||||
|
with ops.name_scope(name, default_name, (top_k_predictions, labels,
|
||||||
|
weights)) as name_scope:
|
||||||
|
return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access
|
||||||
|
labels=labels,
|
||||||
|
predictions_idx=top_k_predictions,
|
||||||
|
class_id=class_id,
|
||||||
|
weights=weights,
|
||||||
|
metrics_collections=metrics_collections,
|
||||||
|
updates_collections=updates_collections,
|
||||||
|
name=name_scope)
|
||||||
|
|
||||||
|
|
||||||
def streaming_sparse_average_precision_at_k(predictions,
|
def streaming_sparse_average_precision_at_k(predictions,
|
||||||
labels,
|
labels,
|
||||||
k,
|
k,
|
||||||
@ -2288,6 +2369,7 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
'aggregate_metric_map',
|
'aggregate_metric_map',
|
||||||
'aggregate_metrics',
|
'aggregate_metrics',
|
||||||
|
'sparse_recall_at_top_k',
|
||||||
'streaming_accuracy',
|
'streaming_accuracy',
|
||||||
'streaming_auc',
|
'streaming_auc',
|
||||||
'streaming_false_negatives',
|
'streaming_false_negatives',
|
||||||
@ -2310,7 +2392,9 @@ __all__ = [
|
|||||||
'streaming_root_mean_squared_error',
|
'streaming_root_mean_squared_error',
|
||||||
'streaming_sensitivity_at_specificity',
|
'streaming_sensitivity_at_specificity',
|
||||||
'streaming_sparse_average_precision_at_k',
|
'streaming_sparse_average_precision_at_k',
|
||||||
|
'streaming_sparse_average_precision_at_top_k',
|
||||||
'streaming_sparse_precision_at_k',
|
'streaming_sparse_precision_at_k',
|
||||||
|
'streaming_sparse_precision_at_top_k',
|
||||||
'streaming_sparse_recall_at_k',
|
'streaming_sparse_recall_at_k',
|
||||||
'streaming_specificity_at_sensitivity',
|
'streaming_specificity_at_sensitivity',
|
||||||
'streaming_true_negatives',
|
'streaming_true_negatives',
|
||||||
|
@ -2958,8 +2958,38 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
self.assertEqual(expected, update.eval())
|
self.assertEqual(expected, update.eval())
|
||||||
self.assertEqual(expected, metric.eval())
|
self.assertEqual(expected, metric.eval())
|
||||||
|
|
||||||
|
def _test_sparse_recall_at_top_k(self,
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected,
|
||||||
|
class_id=None,
|
||||||
|
weights=None):
|
||||||
|
with ops.Graph().as_default() as g, self.test_session(g):
|
||||||
|
if weights is not None:
|
||||||
|
weights = constant_op.constant(weights, dtypes_lib.float32)
|
||||||
|
metric, update = metric_ops.sparse_recall_at_top_k(
|
||||||
|
labels=labels,
|
||||||
|
top_k_predictions=constant_op.constant(top_k_predictions,
|
||||||
|
dtypes_lib.int32),
|
||||||
|
class_id=class_id,
|
||||||
|
weights=weights)
|
||||||
|
|
||||||
|
# Fails without initialized vars.
|
||||||
|
self.assertRaises(errors_impl.OpError, metric.eval)
|
||||||
|
self.assertRaises(errors_impl.OpError, update.eval)
|
||||||
|
variables.variables_initializer(variables.local_variables()).run()
|
||||||
|
|
||||||
|
# Run per-step op and assert expected values.
|
||||||
|
if math.isnan(expected):
|
||||||
|
self.assertTrue(math.isnan(update.eval()))
|
||||||
|
self.assertTrue(math.isnan(metric.eval()))
|
||||||
|
else:
|
||||||
|
self.assertEqual(expected, update.eval())
|
||||||
|
self.assertEqual(expected, metric.eval())
|
||||||
|
|
||||||
def test_one_label_at_k1_nan(self):
|
def test_one_label_at_k1_nan(self):
|
||||||
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||||
|
top_k_predictions = [[3], [3]]
|
||||||
sparse_labels = _binary_2d_label_to_sparse_value(
|
sparse_labels = _binary_2d_label_to_sparse_value(
|
||||||
[[0, 0, 0, 1], [0, 0, 1, 0]])
|
[[0, 0, 0, 1], [0, 0, 1, 0]])
|
||||||
dense_labels = np.array([[3], [2]], dtype=np.int64)
|
dense_labels = np.array([[3], [2]], dtype=np.int64)
|
||||||
@ -2970,9 +3000,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
for class_id in (-1, 0, 1, 4):
|
for class_id in (-1, 0, 1, 4):
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=NAN, class_id=class_id)
|
predictions, labels, k=1, expected=NAN, class_id=class_id)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=NAN, class_id=class_id)
|
||||||
|
|
||||||
def test_one_label_at_k1_no_predictions(self):
|
def test_one_label_at_k1_no_predictions(self):
|
||||||
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||||
|
top_k_predictions = [[3], [3]]
|
||||||
sparse_labels = _binary_2d_label_to_sparse_value(
|
sparse_labels = _binary_2d_label_to_sparse_value(
|
||||||
[[0, 0, 0, 1], [0, 0, 1, 0]])
|
[[0, 0, 0, 1], [0, 0, 1, 0]])
|
||||||
dense_labels = np.array([[3], [2]], dtype=np.int64)
|
dense_labels = np.array([[3], [2]], dtype=np.int64)
|
||||||
@ -2981,9 +3014,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
# Class 2: 0 predictions.
|
# Class 2: 0 predictions.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=0.0, class_id=2)
|
predictions, labels, k=1, expected=0.0, class_id=2)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=0.0, class_id=2)
|
||||||
|
|
||||||
def test_one_label_at_k1(self):
|
def test_one_label_at_k1(self):
|
||||||
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||||
|
top_k_predictions = [[3], [3]]
|
||||||
sparse_labels = _binary_2d_label_to_sparse_value(
|
sparse_labels = _binary_2d_label_to_sparse_value(
|
||||||
[[0, 0, 0, 1], [0, 0, 1, 0]])
|
[[0, 0, 0, 1], [0, 0, 1, 0]])
|
||||||
dense_labels = np.array([[3], [2]], dtype=np.int64)
|
dense_labels = np.array([[3], [2]], dtype=np.int64)
|
||||||
@ -2992,13 +3028,18 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
# Class 3: 1 label, 2 predictions, 1 correct.
|
# Class 3: 1 label, 2 predictions, 1 correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=1.0 / 1, class_id=3)
|
predictions, labels, k=1, expected=1.0 / 1, class_id=3)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=1.0 / 1, class_id=3)
|
||||||
|
|
||||||
# All classes: 2 labels, 2 predictions, 1 correct.
|
# All classes: 2 labels, 2 predictions, 1 correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=1.0 / 2)
|
predictions, labels, k=1, expected=1.0 / 2)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=1.0 / 2)
|
||||||
|
|
||||||
def test_one_label_at_k1_weighted(self):
|
def test_one_label_at_k1_weighted(self):
|
||||||
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||||
|
top_k_predictions = [[3], [3]]
|
||||||
sparse_labels = _binary_2d_label_to_sparse_value(
|
sparse_labels = _binary_2d_label_to_sparse_value(
|
||||||
[[0, 0, 0, 1], [0, 0, 1, 0]])
|
[[0, 0, 0, 1], [0, 0, 1, 0]])
|
||||||
dense_labels = np.array([[3], [2]], dtype=np.int64)
|
dense_labels = np.array([[3], [2]], dtype=np.int64)
|
||||||
@ -3007,6 +3048,8 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
# Class 3: 1 label, 2 predictions, 1 correct.
|
# Class 3: 1 label, 2 predictions, 1 correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
|
predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3014,6 +3057,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=1.0 / 1,
|
expected=1.0 / 1,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(1.0,))
|
weights=(1.0,))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=1.0 / 1,
|
||||||
|
class_id=3,
|
||||||
|
weights=(1.0,))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3021,6 +3070,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=1.0 / 1,
|
expected=1.0 / 1,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(2.0,))
|
weights=(2.0,))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=1.0 / 1,
|
||||||
|
class_id=3,
|
||||||
|
weights=(2.0,))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3028,6 +3083,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=NAN,
|
expected=NAN,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(0.0, 0.0))
|
weights=(0.0, 0.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=NAN,
|
||||||
|
class_id=3,
|
||||||
|
weights=(0.0, 0.0))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3035,6 +3096,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=NAN,
|
expected=NAN,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(0.0, 1.0))
|
weights=(0.0, 1.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=NAN,
|
||||||
|
class_id=3,
|
||||||
|
weights=(0.0, 1.0))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3042,6 +3109,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=1.0 / 1,
|
expected=1.0 / 1,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(1.0, 0.0))
|
weights=(1.0, 0.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=1.0 / 1,
|
||||||
|
class_id=3,
|
||||||
|
weights=(1.0, 0.0))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3049,6 +3122,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=1.0 / 1,
|
expected=1.0 / 1,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(1.0, 1.0))
|
weights=(1.0, 1.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=1.0 / 1,
|
||||||
|
class_id=3,
|
||||||
|
weights=(1.0, 1.0))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3056,6 +3135,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=2.0 / 2,
|
expected=2.0 / 2,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(2.0, 3.0))
|
weights=(2.0, 3.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=2.0 / 2,
|
||||||
|
class_id=3,
|
||||||
|
weights=(2.0, 3.0))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3063,6 +3148,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=3.0 / 3,
|
expected=3.0 / 3,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(3.0, 2.0))
|
weights=(3.0, 2.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=3.0 / 3,
|
||||||
|
class_id=3,
|
||||||
|
weights=(3.0, 2.0))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3070,6 +3161,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=0.3 / 0.3,
|
expected=0.3 / 0.3,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(0.3, 0.6))
|
weights=(0.3, 0.6))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=0.3 / 0.3,
|
||||||
|
class_id=3,
|
||||||
|
weights=(0.3, 0.6))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3077,32 +3174,70 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=0.6 / 0.6,
|
expected=0.6 / 0.6,
|
||||||
class_id=3,
|
class_id=3,
|
||||||
weights=(0.6, 0.3))
|
weights=(0.6, 0.3))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=0.6 / 0.6,
|
||||||
|
class_id=3,
|
||||||
|
weights=(0.6, 0.3))
|
||||||
|
|
||||||
# All classes: 2 labels, 2 predictions, 1 correct.
|
# All classes: 2 labels, 2 predictions, 1 correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=NAN, weights=(0.0,))
|
predictions, labels, k=1, expected=NAN, weights=(0.0,))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=NAN, weights=(0.0,))
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
|
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
|
||||||
|
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
|
predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
|
||||||
|
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
|
predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
|
||||||
|
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
|
predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
|
||||||
|
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
|
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
|
||||||
|
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
|
predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
|
||||||
|
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
|
predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
|
||||||
|
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
|
predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
|
||||||
|
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
|
predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
|
||||||
|
|
||||||
def test_three_labels_at_k5_nan(self):
|
def test_three_labels_at_k5_nan(self):
|
||||||
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
|
||||||
|
top_k_predictions = [
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
]
|
||||||
sparse_labels = _binary_2d_label_to_sparse_value(
|
sparse_labels = _binary_2d_label_to_sparse_value(
|
||||||
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
|
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
|
||||||
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
|
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
|
||||||
@ -3112,10 +3247,16 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
for class_id in (0, 3, 4, 6, 9, 10):
|
for class_id in (0, 3, 4, 6, 9, 10):
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=NAN, class_id=class_id)
|
predictions, labels, k=5, expected=NAN, class_id=class_id)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=NAN, class_id=class_id)
|
||||||
|
|
||||||
def test_three_labels_at_k5_no_predictions(self):
|
def test_three_labels_at_k5_no_predictions(self):
|
||||||
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
|
||||||
|
top_k_predictions = [
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
]
|
||||||
sparse_labels = _binary_2d_label_to_sparse_value(
|
sparse_labels = _binary_2d_label_to_sparse_value(
|
||||||
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
|
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
|
||||||
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
|
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
|
||||||
@ -3124,10 +3265,16 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
# Class 8: 1 label, no predictions.
|
# Class 8: 1 label, no predictions.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=0.0 / 1, class_id=8)
|
predictions, labels, k=5, expected=0.0 / 1, class_id=8)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=0.0 / 1, class_id=8)
|
||||||
|
|
||||||
def test_three_labels_at_k5(self):
|
def test_three_labels_at_k5(self):
|
||||||
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
|
||||||
|
top_k_predictions = [
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
]
|
||||||
sparse_labels = _binary_2d_label_to_sparse_value(
|
sparse_labels = _binary_2d_label_to_sparse_value(
|
||||||
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
|
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
|
||||||
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
|
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
|
||||||
@ -3136,23 +3283,35 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
# Class 2: 2 labels, both correct.
|
# Class 2: 2 labels, both correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=2.0 / 2, class_id=2)
|
predictions, labels, k=5, expected=2.0 / 2, class_id=2)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=2.0 / 2, class_id=2)
|
||||||
|
|
||||||
# Class 5: 1 label, incorrect.
|
# Class 5: 1 label, incorrect.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=1.0 / 1, class_id=5)
|
predictions, labels, k=5, expected=1.0 / 1, class_id=5)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=1.0 / 1, class_id=5)
|
||||||
|
|
||||||
# Class 7: 1 label, incorrect.
|
# Class 7: 1 label, incorrect.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=0.0 / 1, class_id=7)
|
predictions, labels, k=5, expected=0.0 / 1, class_id=7)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=0.0 / 1, class_id=7)
|
||||||
|
|
||||||
# All classes: 6 labels, 3 correct.
|
# All classes: 6 labels, 3 correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=3.0 / 6)
|
predictions, labels, k=5, expected=3.0 / 6)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=3.0 / 6)
|
||||||
|
|
||||||
def test_three_labels_at_k5_some_out_of_range(self):
|
def test_three_labels_at_k5_some_out_of_range(self):
|
||||||
"""Tests that labels outside the [0, n_classes) count in denominator."""
|
"""Tests that labels outside the [0, n_classes) count in denominator."""
|
||||||
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
|
||||||
|
top_k_predictions = [
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
]
|
||||||
sp_labels = sparse_tensor.SparseTensorValue(
|
sp_labels = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
|
indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
|
||||||
[1, 3]],
|
[1, 3]],
|
||||||
@ -3167,6 +3326,11 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
k=5,
|
k=5,
|
||||||
expected=2.0 / 2,
|
expected=2.0 / 2,
|
||||||
class_id=2)
|
class_id=2)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
sp_labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=2.0 / 2,
|
||||||
|
class_id=2)
|
||||||
|
|
||||||
# Class 5: 1 label, incorrect.
|
# Class 5: 1 label, incorrect.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
@ -3175,6 +3339,11 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
k=5,
|
k=5,
|
||||||
expected=1.0 / 1,
|
expected=1.0 / 1,
|
||||||
class_id=5)
|
class_id=5)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
sp_labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=1.0 / 1,
|
||||||
|
class_id=5)
|
||||||
|
|
||||||
# Class 7: 1 label, incorrect.
|
# Class 7: 1 label, incorrect.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
@ -3183,16 +3352,30 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
k=5,
|
k=5,
|
||||||
expected=0.0 / 1,
|
expected=0.0 / 1,
|
||||||
class_id=7)
|
class_id=7)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
sp_labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=0.0 / 1,
|
||||||
|
class_id=7)
|
||||||
|
|
||||||
# All classes: 8 labels, 3 correct.
|
# All classes: 8 labels, 3 correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
|
predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
sp_labels, top_k_predictions, expected=3.0 / 8)
|
||||||
|
|
||||||
def test_3d_nan(self):
|
def test_3d_nan(self):
|
||||||
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
||||||
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
||||||
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
||||||
|
top_k_predictions = [[
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
], [
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
]]
|
||||||
sparse_labels = _binary_3d_label_to_sparse_value(
|
sparse_labels = _binary_3d_label_to_sparse_value(
|
||||||
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
||||||
[[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
|
[[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
|
||||||
@ -3207,12 +3390,21 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
for class_id in (0, 3, 4, 6, 9, 10):
|
for class_id in (0, 3, 4, 6, 9, 10):
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=NAN, class_id=class_id)
|
predictions, labels, k=5, expected=NAN, class_id=class_id)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=NAN, class_id=class_id)
|
||||||
|
|
||||||
def test_3d_no_predictions(self):
|
def test_3d_no_predictions(self):
|
||||||
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
||||||
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
||||||
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
||||||
|
top_k_predictions = [[
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
], [
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
]]
|
||||||
sparse_labels = _binary_3d_label_to_sparse_value(
|
sparse_labels = _binary_3d_label_to_sparse_value(
|
||||||
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
|
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
|
||||||
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
||||||
@ -3229,12 +3421,21 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
for class_id in (1, 8):
|
for class_id in (1, 8):
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=0.0, class_id=class_id)
|
predictions, labels, k=5, expected=0.0, class_id=class_id)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=0.0, class_id=class_id)
|
||||||
|
|
||||||
def test_3d(self):
|
def test_3d(self):
|
||||||
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
||||||
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
||||||
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
||||||
|
top_k_predictions = [[
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
], [
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
]]
|
||||||
labels = _binary_3d_label_to_sparse_value(
|
labels = _binary_3d_label_to_sparse_value(
|
||||||
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
|
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
|
||||||
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
||||||
@ -3244,24 +3445,39 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
# Class 2: 4 labels, all correct.
|
# Class 2: 4 labels, all correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=4.0 / 4, class_id=2)
|
predictions, labels, k=5, expected=4.0 / 4, class_id=2)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=4.0 / 4, class_id=2)
|
||||||
|
|
||||||
# Class 5: 2 labels, both correct.
|
# Class 5: 2 labels, both correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=2.0 / 2, class_id=5)
|
predictions, labels, k=5, expected=2.0 / 2, class_id=5)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=2.0 / 2, class_id=5)
|
||||||
|
|
||||||
# Class 7: 2 labels, 1 incorrect.
|
# Class 7: 2 labels, 1 incorrect.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=1.0 / 2, class_id=7)
|
predictions, labels, k=5, expected=1.0 / 2, class_id=7)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=1.0 / 2, class_id=7)
|
||||||
|
|
||||||
# All classes: 12 labels, 7 correct.
|
# All classes: 12 labels, 7 correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=7.0 / 12)
|
predictions, labels, k=5, expected=7.0 / 12)
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=7.0 / 12)
|
||||||
|
|
||||||
def test_3d_ignore_all(self):
|
def test_3d_ignore_all(self):
|
||||||
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
||||||
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
||||||
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
||||||
|
top_k_predictions = [[
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
], [
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
]]
|
||||||
labels = _binary_3d_label_to_sparse_value(
|
labels = _binary_3d_label_to_sparse_value(
|
||||||
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
|
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
|
||||||
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
||||||
@ -3276,6 +3492,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=NAN,
|
expected=NAN,
|
||||||
class_id=class_id,
|
class_id=class_id,
|
||||||
weights=[[0], [0]])
|
weights=[[0], [0]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=NAN,
|
||||||
|
class_id=class_id,
|
||||||
|
weights=[[0], [0]])
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions,
|
predictions,
|
||||||
labels,
|
labels,
|
||||||
@ -3283,16 +3505,33 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=NAN,
|
expected=NAN,
|
||||||
class_id=class_id,
|
class_id=class_id,
|
||||||
weights=[[0, 0], [0, 0]])
|
weights=[[0, 0], [0, 0]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=NAN,
|
||||||
|
class_id=class_id,
|
||||||
|
weights=[[0, 0], [0, 0]])
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
|
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=NAN, weights=[[0], [0]])
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
|
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels, top_k_predictions, expected=NAN, weights=[[0, 0], [0, 0]])
|
||||||
|
|
||||||
def test_3d_ignore_some(self):
|
def test_3d_ignore_some(self):
|
||||||
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
|
||||||
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
|
||||||
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
|
||||||
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
|
||||||
|
top_k_predictions = [[
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
], [
|
||||||
|
[5, 7, 2, 9, 6],
|
||||||
|
[9, 4, 6, 2, 0],
|
||||||
|
]]
|
||||||
labels = _binary_3d_label_to_sparse_value(
|
labels = _binary_3d_label_to_sparse_value(
|
||||||
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
|
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
|
||||||
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
|
||||||
@ -3307,6 +3546,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=2.0 / 2.0,
|
expected=2.0 / 2.0,
|
||||||
class_id=2,
|
class_id=2,
|
||||||
weights=[[1], [0]])
|
weights=[[1], [0]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=2.0 / 2.0,
|
||||||
|
class_id=2,
|
||||||
|
weights=[[1], [0]])
|
||||||
|
|
||||||
# Class 2: 2 labels, both correct.
|
# Class 2: 2 labels, both correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
@ -3316,6 +3561,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=2.0 / 2.0,
|
expected=2.0 / 2.0,
|
||||||
class_id=2,
|
class_id=2,
|
||||||
weights=[[0], [1]])
|
weights=[[0], [1]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=2.0 / 2.0,
|
||||||
|
class_id=2,
|
||||||
|
weights=[[0], [1]])
|
||||||
|
|
||||||
# Class 7: 1 label, correct.
|
# Class 7: 1 label, correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
@ -3325,6 +3576,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=1.0 / 1.0,
|
expected=1.0 / 1.0,
|
||||||
class_id=7,
|
class_id=7,
|
||||||
weights=[[0], [1]])
|
weights=[[0], [1]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=1.0 / 1.0,
|
||||||
|
class_id=7,
|
||||||
|
weights=[[0], [1]])
|
||||||
|
|
||||||
# Class 7: 1 label, incorrect.
|
# Class 7: 1 label, incorrect.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
@ -3334,6 +3591,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=0.0 / 1.0,
|
expected=0.0 / 1.0,
|
||||||
class_id=7,
|
class_id=7,
|
||||||
weights=[[1], [0]])
|
weights=[[1], [0]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=0.0 / 1.0,
|
||||||
|
class_id=7,
|
||||||
|
weights=[[1], [0]])
|
||||||
|
|
||||||
# Class 7: 2 labels, 1 correct.
|
# Class 7: 2 labels, 1 correct.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
@ -3343,6 +3606,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=1.0 / 2.0,
|
expected=1.0 / 2.0,
|
||||||
class_id=7,
|
class_id=7,
|
||||||
weights=[[1, 0], [1, 0]])
|
weights=[[1, 0], [1, 0]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=1.0 / 2.0,
|
||||||
|
class_id=7,
|
||||||
|
weights=[[1, 0], [1, 0]])
|
||||||
|
|
||||||
# Class 7: No labels.
|
# Class 7: No labels.
|
||||||
self._test_streaming_sparse_recall_at_k(
|
self._test_streaming_sparse_recall_at_k(
|
||||||
@ -3352,6 +3621,12 @@ class StreamingSparseRecallTest(test.TestCase):
|
|||||||
expected=NAN,
|
expected=NAN,
|
||||||
class_id=7,
|
class_id=7,
|
||||||
weights=[[0, 1], [0, 1]])
|
weights=[[0, 1], [0, 1]])
|
||||||
|
self._test_sparse_recall_at_top_k(
|
||||||
|
labels,
|
||||||
|
top_k_predictions,
|
||||||
|
expected=NAN,
|
||||||
|
class_id=7,
|
||||||
|
weights=[[0, 1], [0, 1]])
|
||||||
|
|
||||||
def test_sparse_tensor_value(self):
|
def test_sparse_tensor_value(self):
|
||||||
predictions = [[0.1, 0.3, 0.2, 0.4],
|
predictions = [[0.1, 0.3, 0.2, 0.4],
|
||||||
|
@ -304,6 +304,7 @@ filegroup(
|
|||||||
exclude = [
|
exclude = [
|
||||||
"**/METADATA",
|
"**/METADATA",
|
||||||
"**/OWNERS",
|
"**/OWNERS",
|
||||||
|
"tools/**",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
visibility = ["//tensorflow:__subpackages__"],
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
@ -351,3 +352,27 @@ tf_kernel_library(
|
|||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "checkpoint_convert",
|
||||||
|
srcs = ["python/tools/checkpoint_convert.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//tensorflow/python:training",
|
||||||
|
"//tensorflow/python:variables",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "checkpoint_convert_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/tools/checkpoint_convert_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
tags = ["no_pip"],
|
||||||
|
deps = [
|
||||||
|
":checkpoint_convert",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -74,7 +74,41 @@ class RNNCellTest(test.TestCase):
|
|||||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||||
x = array_ops.zeros([1, 2])
|
x = array_ops.zeros([1, 2])
|
||||||
m = array_ops.zeros([1, 2])
|
m = array_ops.zeros([1, 2])
|
||||||
g, _ = core_rnn_cell_impl.BasicRNNCell(2)(x, m)
|
cell = core_rnn_cell_impl.BasicRNNCell(2)
|
||||||
|
g, _ = cell(x, m)
|
||||||
|
self.assertEqual(
|
||||||
|
["root/basic_rnn_cell/%s:0"
|
||||||
|
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||||
|
"root/basic_rnn_cell/%s:0"
|
||||||
|
% core_rnn_cell_impl._BIAS_VARIABLE_NAME],
|
||||||
|
[v.name for v in cell.trainable_variables])
|
||||||
|
self.assertFalse(cell.non_trainable_variables)
|
||||||
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
|
res = sess.run(
|
||||||
|
[g], {x.name: np.array([[1., 1.]]),
|
||||||
|
m.name: np.array([[0.1, 0.1]])})
|
||||||
|
self.assertEqual(res[0].shape, (1, 2))
|
||||||
|
|
||||||
|
def testBasicRNNCellNotTrainable(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
def not_trainable_getter(getter, *args, **kwargs):
|
||||||
|
kwargs["trainable"] = False
|
||||||
|
return getter(*args, **kwargs)
|
||||||
|
|
||||||
|
with variable_scope.variable_scope(
|
||||||
|
"root", initializer=init_ops.constant_initializer(0.5),
|
||||||
|
custom_getter=not_trainable_getter):
|
||||||
|
x = array_ops.zeros([1, 2])
|
||||||
|
m = array_ops.zeros([1, 2])
|
||||||
|
cell = core_rnn_cell_impl.BasicRNNCell(2)
|
||||||
|
g, _ = cell(x, m)
|
||||||
|
self.assertFalse(cell.trainable_variables)
|
||||||
|
self.assertEqual(
|
||||||
|
["root/basic_rnn_cell/%s:0"
|
||||||
|
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||||
|
"root/basic_rnn_cell/%s:0"
|
||||||
|
% core_rnn_cell_impl._BIAS_VARIABLE_NAME],
|
||||||
|
[v.name for v in cell.non_trainable_variables])
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run(
|
||||||
[g], {x.name: np.array([[1., 1.]]),
|
[g], {x.name: np.array([[1., 1.]]),
|
||||||
@ -114,10 +148,23 @@ class RNNCellTest(test.TestCase):
|
|||||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||||
x = array_ops.zeros([1, 2])
|
x = array_ops.zeros([1, 2])
|
||||||
m = array_ops.zeros([1, 8])
|
m = array_ops.zeros([1, 8])
|
||||||
g, out_m = core_rnn_cell_impl.MultiRNNCell(
|
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||||
[core_rnn_cell_impl.BasicLSTMCell(
|
[core_rnn_cell_impl.BasicLSTMCell(
|
||||||
2, state_is_tuple=False) for _ in range(2)],
|
2, state_is_tuple=False) for _ in range(2)],
|
||||||
state_is_tuple=False)(x, m)
|
state_is_tuple=False)
|
||||||
|
g, out_m = cell(x, m)
|
||||||
|
expected_variable_names = [
|
||||||
|
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
|
||||||
|
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||||
|
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
|
||||||
|
% core_rnn_cell_impl._BIAS_VARIABLE_NAME,
|
||||||
|
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
|
||||||
|
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||||
|
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
|
||||||
|
% core_rnn_cell_impl._BIAS_VARIABLE_NAME]
|
||||||
|
self.assertEqual(
|
||||||
|
expected_variable_names, [v.name for v in cell.trainable_variables])
|
||||||
|
self.assertFalse(cell.non_trainable_variables)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run(
|
||||||
[g, out_m],
|
[g, out_m],
|
||||||
@ -125,15 +172,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
m.name: 0.1 * np.ones([1, 8])})
|
m.name: 0.1 * np.ones([1, 8])})
|
||||||
self.assertEqual(len(res), 2)
|
self.assertEqual(len(res), 2)
|
||||||
variables = variables_lib.global_variables()
|
variables = variables_lib.global_variables()
|
||||||
self.assertEqual(4, len(variables))
|
self.assertEqual(expected_variable_names, [v.name for v in variables])
|
||||||
self.assertEquals(variables[0].op.name,
|
|
||||||
"root/multi_rnn_cell/cell_0/basic_lstm_cell/weights")
|
|
||||||
self.assertEquals(variables[1].op.name,
|
|
||||||
"root/multi_rnn_cell/cell_0/basic_lstm_cell/biases")
|
|
||||||
self.assertEquals(variables[2].op.name,
|
|
||||||
"root/multi_rnn_cell/cell_1/basic_lstm_cell/weights")
|
|
||||||
self.assertEquals(variables[3].op.name,
|
|
||||||
"root/multi_rnn_cell/cell_1/basic_lstm_cell/biases")
|
|
||||||
# The numbers in results were not calculated, this is just a smoke test.
|
# The numbers in results were not calculated, this is just a smoke test.
|
||||||
self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
|
self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
|
||||||
expected_mem = np.array([[
|
expected_mem = np.array([[
|
||||||
|
@ -27,7 +27,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import math
|
import math
|
||||||
import numbers
|
import numbers
|
||||||
@ -57,53 +56,6 @@ _BIAS_VARIABLE_NAME = "biases"
|
|||||||
_WEIGHTS_VARIABLE_NAME = "weights"
|
_WEIGHTS_VARIABLE_NAME = "weights"
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def _checked_scope(cell, scope, reuse=None, **kwargs):
|
|
||||||
if reuse is not None:
|
|
||||||
kwargs["reuse"] = reuse
|
|
||||||
with vs.variable_scope(scope, **kwargs) as checking_scope:
|
|
||||||
scope_name = checking_scope.name
|
|
||||||
if hasattr(cell, "_scope"):
|
|
||||||
cell_scope = cell._scope # pylint: disable=protected-access
|
|
||||||
if cell_scope.name != checking_scope.name:
|
|
||||||
raise ValueError(
|
|
||||||
"Attempt to reuse RNNCell %s with a different variable scope than "
|
|
||||||
"its first use. First use of cell was with scope '%s', this "
|
|
||||||
"attempt is with scope '%s'. Please create a new instance of the "
|
|
||||||
"cell if you would like it to use a different set of weights. "
|
|
||||||
"If before you were using: MultiRNNCell([%s(...)] * num_layers), "
|
|
||||||
"change to: MultiRNNCell([%s(...) for _ in range(num_layers)]). "
|
|
||||||
"If before you were using the same cell instance as both the "
|
|
||||||
"forward and reverse cell of a bidirectional RNN, simply create "
|
|
||||||
"two instances (one for forward, one for reverse). "
|
|
||||||
"In May 2017, we will start transitioning this cell's behavior "
|
|
||||||
"to use existing stored weights, if any, when it is called "
|
|
||||||
"with scope=None (which can lead to silent model degradation, so "
|
|
||||||
"this error will remain until then.)"
|
|
||||||
% (cell, cell_scope.name, scope_name, type(cell).__name__,
|
|
||||||
type(cell).__name__))
|
|
||||||
else:
|
|
||||||
weights_found = False
|
|
||||||
try:
|
|
||||||
with vs.variable_scope(checking_scope, reuse=True):
|
|
||||||
vs.get_variable(_WEIGHTS_VARIABLE_NAME)
|
|
||||||
weights_found = True
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
if weights_found and reuse is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Attempt to have a second RNNCell use the weights of a variable "
|
|
||||||
"scope that already has weights: '%s'; and the cell was not "
|
|
||||||
"constructed as %s(..., reuse=True). "
|
|
||||||
"To share the weights of an RNNCell, simply "
|
|
||||||
"reuse it in your second calculation, or create a new one with "
|
|
||||||
"the argument reuse=True." % (scope_name, type(cell).__name__))
|
|
||||||
|
|
||||||
# Everything is OK. Update the cell's scope and yield it.
|
|
||||||
cell._scope = checking_scope # pylint: disable=protected-access
|
|
||||||
yield checking_scope
|
|
||||||
|
|
||||||
|
|
||||||
class BasicRNNCell(RNNCell):
|
class BasicRNNCell(RNNCell):
|
||||||
"""The most basic RNN cell."""
|
"""The most basic RNN cell."""
|
||||||
|
|
||||||
|
@ -39,9 +39,6 @@ from tensorflow.python.platform import tf_logging as logging
|
|||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
_checked_scope = core_rnn_cell_impl._checked_scope # pylint: disable=protected-access
|
|
||||||
|
|
||||||
|
|
||||||
def _get_concat_variable(name, shape, dtype, num_shards):
|
def _get_concat_variable(name, shape, dtype, num_shards):
|
||||||
"""Get a sharded variable concatenated into one tensor."""
|
"""Get a sharded variable concatenated into one tensor."""
|
||||||
sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
|
sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
|
||||||
|
231
tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
Normal file
231
tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
r"""Convert checkpoints using RNNCells to new name convention.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
python checkpoint_convert [--write_v1_checkpoint] \
|
||||||
|
'/path/to/checkpoint' '/path/to/new_checkpoint'
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import collections
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tensorflow.core.protobuf import saver_pb2
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import app
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.training import saver as saver_lib
|
||||||
|
|
||||||
|
_RNN_NAME_REPLACEMENTS = collections.OrderedDict([
|
||||||
|
############################################################################
|
||||||
|
# contrib/rnn/python/ops/core_rnn_cell_impl.py
|
||||||
|
# BasicRNNCell
|
||||||
|
('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
|
||||||
|
('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
|
||||||
|
# GRUCell
|
||||||
|
('gru_cell/weights', 'gru_cell/kernel'),
|
||||||
|
('gru_cell/biases', 'gru_cell/bias'),
|
||||||
|
('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
|
||||||
|
('gru_cell/gates/biases', 'gru_cell/gates/bias'),
|
||||||
|
('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
|
||||||
|
('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
|
||||||
|
# BasicLSTMCell
|
||||||
|
('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
|
||||||
|
('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
|
||||||
|
# LSTMCell
|
||||||
|
('lstm_cell/weights', 'lstm_cell/kernel'),
|
||||||
|
('lstm_cell/biases', 'lstm_cell/bias'),
|
||||||
|
('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
|
||||||
|
('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
|
||||||
|
# OutputProjectionWrapper
|
||||||
|
('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
|
||||||
|
('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
|
||||||
|
# InputProjectionWrapper
|
||||||
|
('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
|
||||||
|
('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
|
||||||
|
############################################################################
|
||||||
|
# contrib/rnn/python/ops/lstm_ops.py
|
||||||
|
# LSTMBlockFusedCell ??
|
||||||
|
('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
|
||||||
|
('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
|
||||||
|
############################################################################
|
||||||
|
# contrib/rnn/python/ops/rnn_cell.py
|
||||||
|
# LayerNormBasicLSTMCell
|
||||||
|
('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
|
||||||
|
('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
|
||||||
|
# UGRNNCell, not found in g3, but still need it?
|
||||||
|
('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
|
||||||
|
('ugrnn_cell/biases', 'ugrnn_cell/bias'),
|
||||||
|
# NASCell
|
||||||
|
('nas_rnn/weights', 'nas_rnn/kernel'),
|
||||||
|
('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
|
||||||
|
# IntersectionRNNCell
|
||||||
|
('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
|
||||||
|
('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
|
||||||
|
('intersection_rnn_cell/in_projection/weights',
|
||||||
|
'intersection_rnn_cell/in_projection/kernel'),
|
||||||
|
('intersection_rnn_cell/in_projection/biases',
|
||||||
|
'intersection_rnn_cell/in_projection/bias'),
|
||||||
|
# PhasedLSTMCell
|
||||||
|
('phased_lstm_cell/mask_gates/weights',
|
||||||
|
'phased_lstm_cell/mask_gates/kernel'),
|
||||||
|
('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
|
||||||
|
('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
|
||||||
|
('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
|
||||||
|
('phased_lstm_cell/output_gate/weights',
|
||||||
|
'phased_lstm_cell/output_gate/kernel'),
|
||||||
|
('phased_lstm_cell/output_gate/biases',
|
||||||
|
'phased_lstm_cell/output_gate/bias'),
|
||||||
|
# AttentionCellWrapper
|
||||||
|
('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
|
||||||
|
('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
|
||||||
|
('attention_cell_wrapper/attn_output_projection/weights',
|
||||||
|
'attention_cell_wrapper/attn_output_projection/kernel'),
|
||||||
|
('attention_cell_wrapper/attn_output_projection/biases',
|
||||||
|
'attention_cell_wrapper/attn_output_projection/bias'),
|
||||||
|
('attention_cell_wrapper/attention/weights',
|
||||||
|
'attention_cell_wrapper/attention/kernel'),
|
||||||
|
('attention_cell_wrapper/attention/biases',
|
||||||
|
'attention_cell_wrapper/attention/bias'),
|
||||||
|
])
|
||||||
|
|
||||||
|
_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
|
||||||
|
('LSTMCell/W_', 'lstm_cell/weights/part_'),
|
||||||
|
('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
|
||||||
|
('GRUCell/W_', 'gru_cell/weights/part_'),
|
||||||
|
('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _rnn_name_replacement(var_name):
|
||||||
|
for pattern in _RNN_NAME_REPLACEMENTS:
|
||||||
|
if pattern in var_name:
|
||||||
|
old_var_name = var_name
|
||||||
|
var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern])
|
||||||
|
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
|
||||||
|
break
|
||||||
|
return var_name
|
||||||
|
|
||||||
|
|
||||||
|
def _rnn_name_replacement_sharded(var_name):
|
||||||
|
for pattern in _RNN_SHARDED_NAME_REPLACEMENTS:
|
||||||
|
if pattern in var_name:
|
||||||
|
old_var_name = var_name
|
||||||
|
var_name = var_name.replace(pattern,
|
||||||
|
_RNN_SHARDED_NAME_REPLACEMENTS[pattern])
|
||||||
|
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
|
||||||
|
return var_name
|
||||||
|
|
||||||
|
|
||||||
|
def _split_sharded_vars(name_shape_map):
|
||||||
|
"""Split shareded variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name_shape_map: A dict from variable name to variable shape.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
not_sharded: Names of the non-sharded variables.
|
||||||
|
sharded: Names of the sharded varibales.
|
||||||
|
"""
|
||||||
|
sharded = []
|
||||||
|
not_sharded = []
|
||||||
|
for name in name_shape_map:
|
||||||
|
if re.match(name, '_[0-9]+$'):
|
||||||
|
if re.sub('_[0-9]+$', '_1', name) in name_shape_map:
|
||||||
|
sharded.append(name)
|
||||||
|
else:
|
||||||
|
not_sharded.append(name)
|
||||||
|
else:
|
||||||
|
not_sharded.append(name)
|
||||||
|
return not_sharded, sharded
|
||||||
|
|
||||||
|
|
||||||
|
def convert_names(checkpoint_from_path,
|
||||||
|
checkpoint_to_path,
|
||||||
|
write_v1_checkpoint=False):
|
||||||
|
"""Migrates the names of variables within a checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_from_path: Path to source checkpoint to be read in.
|
||||||
|
checkpoint_to_path: Path to checkpoint to be written out.
|
||||||
|
write_v1_checkpoint: Whether the output checkpoint will be in V1 format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary that maps the new variable names to the Variable objects.
|
||||||
|
A dictionary that maps the old variable names to the new variable names.
|
||||||
|
"""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
|
||||||
|
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
|
||||||
|
name_shape_map = reader.get_variable_to_shape_map()
|
||||||
|
not_sharded, sharded = _split_sharded_vars(name_shape_map)
|
||||||
|
new_variable_map = {}
|
||||||
|
conversion_map = {}
|
||||||
|
for var_name in not_sharded:
|
||||||
|
new_var_name = _rnn_name_replacement(var_name)
|
||||||
|
tensor = reader.get_tensor(var_name)
|
||||||
|
var = variables.Variable(tensor, name=var_name)
|
||||||
|
new_variable_map[new_var_name] = var
|
||||||
|
if new_var_name != var_name:
|
||||||
|
conversion_map[var_name] = new_var_name
|
||||||
|
for var_name in sharded:
|
||||||
|
new_var_name = _rnn_name_replacement_sharded(var_name)
|
||||||
|
var = variables.Variable(tensor, name=var_name)
|
||||||
|
new_variable_map[new_var_name] = var
|
||||||
|
if new_var_name != var_name:
|
||||||
|
conversion_map[var_name] = new_var_name
|
||||||
|
|
||||||
|
write_version = (saver_pb2.SaverDef.V1
|
||||||
|
if write_v1_checkpoint else saver_pb2.SaverDef.V2)
|
||||||
|
saver = saver_lib.Saver(new_variable_map, write_version=write_version)
|
||||||
|
|
||||||
|
with session.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
|
||||||
|
saver.save(sess, checkpoint_to_path)
|
||||||
|
|
||||||
|
logging.info('Summary:')
|
||||||
|
logging.info(' Converted %d variable name(s).' % len(new_variable_map))
|
||||||
|
return new_variable_map, conversion_map
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
convert_names(
|
||||||
|
FLAGS.checkpoint_from_path,
|
||||||
|
FLAGS.checkpoint_to_path,
|
||||||
|
write_v1_checkpoint=FLAGS.write_v1_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.register('type', 'bool', lambda v: v.lower() == 'true')
|
||||||
|
parser.add_argument('checkpoint_from_path', type=str,
|
||||||
|
help='Path to source checkpoint to be read in.')
|
||||||
|
parser.add_argument('checkpoint_to_path', type=str,
|
||||||
|
help='Path to checkpoint to be written out.')
|
||||||
|
parser.add_argument('--write_v1_checkpoint', action='store_true',
|
||||||
|
help='Write v1 checkpoint')
|
||||||
|
FLAGS, unparsed = parser.parse_known_args()
|
||||||
|
|
||||||
|
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
108
tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
Normal file
108
tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Unit tests for checkpoint converter."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from tensorflow.contrib.rnn.python.tools import checkpoint_convert
|
||||||
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.training import saver as saver_lib
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointConvertTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._old_ckpt_path = tempfile.mktemp()
|
||||||
|
self._new_ckpt_path = tempfile.mktemp()
|
||||||
|
ops.reset_default_graph()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
for file_name in glob.glob(self._old_ckpt_path + "*"):
|
||||||
|
os.remove(file_name)
|
||||||
|
for file_name in glob.glob(self._new_ckpt_path + "*"):
|
||||||
|
os.remove(file_name)
|
||||||
|
|
||||||
|
def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self):
|
||||||
|
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
|
||||||
|
new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name]
|
||||||
|
self.assertTrue(old_name)
|
||||||
|
self.assertTrue(new_name)
|
||||||
|
self.assertNotEqual(old_name, new_name)
|
||||||
|
for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS:
|
||||||
|
new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name]
|
||||||
|
self.assertTrue(old_name)
|
||||||
|
self.assertTrue(new_name)
|
||||||
|
self.assertNotEqual(old_name, new_name)
|
||||||
|
|
||||||
|
def testConversionFromV2WithConvertedVariableNamesSucceeds(self):
|
||||||
|
variables.Variable(10.0, name="a")
|
||||||
|
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
|
||||||
|
variables.Variable(20.0, name=old_name)
|
||||||
|
with session.Session() as sess:
|
||||||
|
saver = saver_lib.Saver()
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
saver.save(sess, self._old_ckpt_path)
|
||||||
|
|
||||||
|
new_var_map, conversion_map = checkpoint_convert.convert_names(
|
||||||
|
self._old_ckpt_path, self._new_ckpt_path)
|
||||||
|
self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
|
||||||
|
self.assertItemsEqual(
|
||||||
|
["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()),
|
||||||
|
new_var_map.keys())
|
||||||
|
self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map)
|
||||||
|
|
||||||
|
def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self):
|
||||||
|
variables.Variable(10.0, name="a")
|
||||||
|
with session.Session() as sess:
|
||||||
|
saver = saver_lib.Saver()
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
saver.save(sess, self._old_ckpt_path)
|
||||||
|
|
||||||
|
new_var_map, conversion_map = checkpoint_convert.convert_names(
|
||||||
|
self._old_ckpt_path, self._new_ckpt_path)
|
||||||
|
self.assertItemsEqual(["a"], new_var_map.keys())
|
||||||
|
self.assertFalse(conversion_map)
|
||||||
|
|
||||||
|
def testConversionToV1Succeeds(self):
|
||||||
|
variables.Variable(10.0, name="a")
|
||||||
|
variables.Variable(
|
||||||
|
20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1])
|
||||||
|
|
||||||
|
with session.Session() as sess:
|
||||||
|
saver = saver_lib.Saver()
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
saver.save(sess, self._old_ckpt_path)
|
||||||
|
|
||||||
|
new_var_map, conversion_map = checkpoint_convert.convert_names(
|
||||||
|
self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True)
|
||||||
|
self.assertItemsEqual(
|
||||||
|
["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]],
|
||||||
|
new_var_map.keys())
|
||||||
|
self.assertEqual(
|
||||||
|
{list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]:
|
||||||
|
list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]},
|
||||||
|
conversion_map)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -261,7 +261,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.ops import clip_ops
|
from tensorflow.python.ops import clip_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variables as tf_variables
|
from tensorflow.python.ops import variables as tf_variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
@ -657,7 +657,7 @@ def train(train_op,
|
|||||||
if local_init_op == _USE_DEFAULT:
|
if local_init_op == _USE_DEFAULT:
|
||||||
local_init_op = control_flow_ops.group(
|
local_init_op = control_flow_ops.group(
|
||||||
tf_variables.local_variables_initializer(),
|
tf_variables.local_variables_initializer(),
|
||||||
data_flow_ops.tables_initializer())
|
lookup_ops.tables_initializer())
|
||||||
|
|
||||||
if sync_optimizer is not None and isinstance(
|
if sync_optimizer is not None and isinstance(
|
||||||
sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
|
sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
|
||||||
|
@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [
|
|||||||
"framework/versions.proto",
|
"framework/versions.proto",
|
||||||
"lib/core/error_codes.proto",
|
"lib/core/error_codes.proto",
|
||||||
"protobuf/config.proto",
|
"protobuf/config.proto",
|
||||||
|
"protobuf/cluster.proto",
|
||||||
"protobuf/debug.proto",
|
"protobuf/debug.proto",
|
||||||
"protobuf/queue_runner.proto",
|
"protobuf/queue_runner.proto",
|
||||||
"protobuf/rewriter_config.proto",
|
"protobuf/rewriter_config.proto",
|
||||||
@ -506,6 +507,7 @@ tf_gen_op_libs(
|
|||||||
"image_ops",
|
"image_ops",
|
||||||
"io_ops",
|
"io_ops",
|
||||||
"linalg_ops",
|
"linalg_ops",
|
||||||
|
"lookup_ops",
|
||||||
"logging_ops",
|
"logging_ops",
|
||||||
"math_ops",
|
"math_ops",
|
||||||
"nn_ops",
|
"nn_ops",
|
||||||
@ -582,6 +584,7 @@ cc_library(
|
|||||||
":image_ops_op_lib",
|
":image_ops_op_lib",
|
||||||
":io_ops_op_lib",
|
":io_ops_op_lib",
|
||||||
":linalg_ops_op_lib",
|
":linalg_ops_op_lib",
|
||||||
|
":lookup_ops_op_lib",
|
||||||
":logging_ops_op_lib",
|
":logging_ops_op_lib",
|
||||||
":math_ops_op_lib",
|
":math_ops_op_lib",
|
||||||
":nn_ops_op_lib",
|
":nn_ops_op_lib",
|
||||||
@ -708,6 +711,7 @@ cc_library(
|
|||||||
"//tensorflow/core/kernels:image",
|
"//tensorflow/core/kernels:image",
|
||||||
"//tensorflow/core/kernels:io",
|
"//tensorflow/core/kernels:io",
|
||||||
"//tensorflow/core/kernels:linalg",
|
"//tensorflow/core/kernels:linalg",
|
||||||
|
"//tensorflow/core/kernels:lookup",
|
||||||
"//tensorflow/core/kernels:logging",
|
"//tensorflow/core/kernels:logging",
|
||||||
"//tensorflow/core/kernels:math",
|
"//tensorflow/core/kernels:math",
|
||||||
"//tensorflow/core/kernels:multinomial_op",
|
"//tensorflow/core/kernels:multinomial_op",
|
||||||
|
@ -23,8 +23,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
Device::Device(Env* env, const DeviceAttributes& device_attributes,
|
Device::Device(Env* env, const DeviceAttributes& device_attributes)
|
||||||
Allocator* device_allocator)
|
|
||||||
: DeviceBase(env), device_attributes_(device_attributes) {
|
: DeviceBase(env), device_attributes_(device_attributes) {
|
||||||
CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
|
CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
|
||||||
<< "Invalid device name: " << name();
|
<< "Invalid device name: " << name();
|
||||||
|
@ -53,8 +53,7 @@ namespace tensorflow {
|
|||||||
|
|
||||||
class Device : public DeviceBase {
|
class Device : public DeviceBase {
|
||||||
public:
|
public:
|
||||||
Device(Env* env, const DeviceAttributes& device_attributes,
|
Device(Env* env, const DeviceAttributes& device_attributes);
|
||||||
Allocator* device_allocator);
|
|
||||||
~Device() override;
|
~Device() override;
|
||||||
|
|
||||||
// Full name of this device (see top comment).
|
// Full name of this device (see top comment).
|
||||||
|
@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
|
|||||||
for (Device* d : devices) {
|
for (Device* d : devices) {
|
||||||
devices_.push_back(d);
|
devices_.push_back(d);
|
||||||
|
|
||||||
// Register under both the full name and the local name.
|
// Register under the (1) full name, (2) canonical name, and (3) local name.
|
||||||
string full_name = d->name();
|
string full_name = d->name();
|
||||||
device_map_[CopyToBackingStore(full_name)] = d;
|
device_map_[CopyToBackingStore(full_name)] = d;
|
||||||
|
|
||||||
|
DeviceNameUtils::ParsedName parsed_name = d->parsed_name();
|
||||||
|
if (parsed_name.has_job && parsed_name.has_replica &&
|
||||||
|
parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) {
|
||||||
|
string canonical_name = DeviceNameUtils::FullName(
|
||||||
|
parsed_name.job, parsed_name.replica, parsed_name.task,
|
||||||
|
parsed_name.type, parsed_name.id);
|
||||||
|
device_map_[CopyToBackingStore(canonical_name)] = d;
|
||||||
|
}
|
||||||
string lname = DeviceNameUtils::LocalName(d->name());
|
string lname = DeviceNameUtils::LocalName(d->name());
|
||||||
device_map_[CopyToBackingStore(lname)] = d;
|
device_map_[CopyToBackingStore(lname)] = d;
|
||||||
device_type_counts_[d->device_type()]++;
|
device_type_counts_[d->device_type()]++;
|
||||||
@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
|
|||||||
}
|
}
|
||||||
|
|
||||||
DeviceMgr::~DeviceMgr() {
|
DeviceMgr::~DeviceMgr() {
|
||||||
for (auto p : devices_) delete p;
|
// TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
|
||||||
|
for (Device* p : devices_) delete p;
|
||||||
}
|
}
|
||||||
|
|
||||||
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
||||||
@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
|||||||
Status s;
|
Status s;
|
||||||
auto iter = device_map_.find(name);
|
auto iter = device_map_.find(name);
|
||||||
if (iter == device_map_.end()) {
|
if (iter == device_map_.end()) {
|
||||||
|
std::vector<StringPiece> device_names;
|
||||||
|
for (auto&& itr : device_map_) {
|
||||||
|
device_names.push_back(itr.first);
|
||||||
|
}
|
||||||
|
LOG(WARNING) << "Unknown device: " << name
|
||||||
|
<< " all devices: " << str_util::Join(device_names, ", ");
|
||||||
return errors::InvalidArgument(name, " unknown device.");
|
return errors::InvalidArgument(name, " unknown device.");
|
||||||
}
|
}
|
||||||
*device = iter->second;
|
*device = iter->second;
|
||||||
|
@ -36,6 +36,7 @@ class DeviceMgr {
|
|||||||
public:
|
public:
|
||||||
// Takes ownership of each device in 'devices'.
|
// Takes ownership of each device in 'devices'.
|
||||||
// TODO(zhifengc): Other initialization information.
|
// TODO(zhifengc): Other initialization information.
|
||||||
|
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
|
||||||
explicit DeviceMgr(const std::vector<Device*>& devices);
|
explicit DeviceMgr(const std::vector<Device*>& devices);
|
||||||
~DeviceMgr();
|
~DeviceMgr();
|
||||||
|
|
||||||
@ -61,6 +62,7 @@ class DeviceMgr {
|
|||||||
int NumDeviceType(const string& type) const;
|
int NumDeviceType(const string& type) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
|
||||||
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
|
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
|
||||||
DeviceVec devices_;
|
DeviceVec devices_;
|
||||||
|
|
||||||
|
@ -39,7 +39,10 @@ class DeviceSet {
|
|||||||
|
|
||||||
// Set the device designated as the "client". This device
|
// Set the device designated as the "client". This device
|
||||||
// must also be registered via AddDevice().
|
// must also be registered via AddDevice().
|
||||||
void set_client_device(Device* device) { client_device_ = device; }
|
void set_client_device(Device* device) {
|
||||||
|
DCHECK(client_device_ == nullptr);
|
||||||
|
client_device_ = device;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns a pointer to the device designated as the "client".
|
// Returns a pointer to the device designated as the "client".
|
||||||
Device* client_device() const { return client_device_; }
|
Device* client_device() const { return client_device_; }
|
||||||
|
@ -27,8 +27,7 @@ namespace {
|
|||||||
static Device* Dev(const char* type, const char* name) {
|
static Device* Dev(const char* type, const char* name) {
|
||||||
class FakeDevice : public Device {
|
class FakeDevice : public Device {
|
||||||
public:
|
public:
|
||||||
explicit FakeDevice(const DeviceAttributes& attr)
|
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
|
||||||
: Device(nullptr, attr, nullptr) {}
|
|
||||||
Status Sync() override { return Status::OK(); }
|
Status Sync() override { return Status::OK(); }
|
||||||
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
||||||
};
|
};
|
||||||
|
@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
|
|||||||
int gpu_id, const string& physical_device_desc,
|
int gpu_id, const string& physical_device_desc,
|
||||||
Allocator* gpu_allocator, Allocator* cpu_allocator,
|
Allocator* gpu_allocator, Allocator* cpu_allocator,
|
||||||
bool sync_every_op, int32 max_streams)
|
bool sync_every_op, int32 max_streams)
|
||||||
: LocalDevice(options,
|
: LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU,
|
||||||
Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit,
|
memory_limit, locality,
|
||||||
locality, physical_device_desc),
|
physical_device_desc)),
|
||||||
gpu_allocator),
|
|
||||||
gpu_allocator_(gpu_allocator),
|
gpu_allocator_(gpu_allocator),
|
||||||
cpu_allocator_(cpu_allocator),
|
cpu_allocator_(cpu_allocator),
|
||||||
gpu_id_(gpu_id),
|
gpu_id_(gpu_id),
|
||||||
|
@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo {
|
|||||||
};
|
};
|
||||||
|
|
||||||
LocalDevice::LocalDevice(const SessionOptions& options,
|
LocalDevice::LocalDevice(const SessionOptions& options,
|
||||||
const DeviceAttributes& attributes,
|
const DeviceAttributes& attributes)
|
||||||
Allocator* device_allocator)
|
: Device(options.env, attributes), owned_tp_info_(nullptr) {
|
||||||
: Device(options.env, attributes, device_allocator),
|
|
||||||
owned_tp_info_(nullptr) {
|
|
||||||
// If we're running on the CPU, log warnings if we're not compiled using the
|
// If we're running on the CPU, log warnings if we're not compiled using the
|
||||||
// best flags for performance.
|
// best flags for performance.
|
||||||
port::WarnAboutUnusedCPUFeatures();
|
port::WarnAboutUnusedCPUFeatures();
|
||||||
|
@ -33,8 +33,8 @@ struct SessionOptions;
|
|||||||
// GPUDevice into more 'process-wide' abstractions.
|
// GPUDevice into more 'process-wide' abstractions.
|
||||||
class LocalDevice : public Device {
|
class LocalDevice : public Device {
|
||||||
public:
|
public:
|
||||||
LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes,
|
LocalDevice(const SessionOptions& options,
|
||||||
Allocator* device_allocator);
|
const DeviceAttributes& attributes);
|
||||||
~LocalDevice() override;
|
~LocalDevice() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
54
tensorflow/core/common_runtime/renamed_device.cc
Normal file
54
tensorflow/core/common_runtime/renamed_device.cc
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/renamed_device.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// TODO(saeta): Convert to returning a std::unique_ptr?
|
||||||
|
/* static */
|
||||||
|
Device* RenamedDevice::NewRenamedDevice(const string& new_base,
|
||||||
|
Device* underlying,
|
||||||
|
bool owns_underlying) {
|
||||||
|
DeviceNameUtils::ParsedName parsed_name;
|
||||||
|
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
|
||||||
|
DeviceNameUtils::ParsedName underlying_parsed_name =
|
||||||
|
underlying->parsed_name();
|
||||||
|
CHECK(underlying_parsed_name.has_type);
|
||||||
|
CHECK(underlying_parsed_name.has_id);
|
||||||
|
parsed_name.type = underlying_parsed_name.type;
|
||||||
|
parsed_name.id = underlying_parsed_name.id;
|
||||||
|
string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica,
|
||||||
|
parsed_name.task, parsed_name.type,
|
||||||
|
parsed_name.id);
|
||||||
|
DeviceAttributes attributes(underlying->attributes());
|
||||||
|
attributes.set_name(name);
|
||||||
|
return new RenamedDevice(underlying, attributes, owns_underlying);
|
||||||
|
}
|
||||||
|
|
||||||
|
RenamedDevice::RenamedDevice(Device* underlying,
|
||||||
|
const DeviceAttributes& attributes,
|
||||||
|
bool owns_underlying)
|
||||||
|
: Device(underlying->env(), attributes),
|
||||||
|
underlying_(underlying),
|
||||||
|
owns_underlying_(owns_underlying) {}
|
||||||
|
|
||||||
|
RenamedDevice::~RenamedDevice() {
|
||||||
|
if (owns_underlying_) {
|
||||||
|
delete underlying_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
119
tensorflow/core/common_runtime/renamed_device.h
Normal file
119
tensorflow/core/common_runtime/renamed_device.h
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Wraps a device with a new name, delegating work to the wrapped device.
|
||||||
|
//
|
||||||
|
// This class is used to wrap local devices when using clusterspec propagation
|
||||||
|
// where the name of a particular device may change in the context of a given
|
||||||
|
// session.
|
||||||
|
class RenamedDevice : public Device {
|
||||||
|
public:
|
||||||
|
static Device* NewRenamedDevice(const string& new_base, Device* underlying,
|
||||||
|
bool owns_underlying);
|
||||||
|
~RenamedDevice() override;
|
||||||
|
|
||||||
|
// Below are virtual methods defined on DeviceBase
|
||||||
|
bool RequiresRecordingAccessedTensors() const override {
|
||||||
|
return underlying_->RequiresRecordingAccessedTensors();
|
||||||
|
}
|
||||||
|
|
||||||
|
const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
|
||||||
|
return underlying_->tensorflow_cpu_worker_threads();
|
||||||
|
}
|
||||||
|
|
||||||
|
const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
|
||||||
|
return underlying_->tensorflow_gpu_device_info();
|
||||||
|
}
|
||||||
|
|
||||||
|
Allocator* GetAllocator(AllocatorAttributes attr) override {
|
||||||
|
return underlying_->GetAllocator(attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Allocator* GetStepAllocator(AllocatorAttributes attr,
|
||||||
|
ResourceMgr* step_resource_manager) override {
|
||||||
|
return underlying_->GetStepAllocator(attr, step_resource_manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
|
||||||
|
return underlying_->eigen_cpu_device();
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
const Eigen::SyclDevice* eigen_sycl_device() const override {
|
||||||
|
return underlying_->eigen_sycl_device();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
PerOpGpuDevice* MakeGpuDevice() override {
|
||||||
|
return underlying_->MakeGpuDevice();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
|
||||||
|
DeviceContext* dc, Allocator* allocator) override {
|
||||||
|
underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||||
|
const AllocatorAttributes alloc_attrs,
|
||||||
|
Tensor* tensor) override {
|
||||||
|
return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Below are virtual methods defined on Device
|
||||||
|
|
||||||
|
void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
|
||||||
|
underlying_->Compute(op_kernel, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
|
||||||
|
AsyncOpKernel::DoneCallback done) override {
|
||||||
|
underlying_->ComputeAsync(op_kernel, context, std::move(done));
|
||||||
|
}
|
||||||
|
|
||||||
|
void ConsumeListOfAccessedTensors(
|
||||||
|
DeviceContext* context, const TensorReferenceVector& tensors) override {
|
||||||
|
underlying_->ConsumeListOfAccessedTensors(context, tensors);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Sync() override { return underlying_->Sync(); }
|
||||||
|
|
||||||
|
Status MaybeRewriteGraph(const FunctionDefLibrary& library,
|
||||||
|
std::unique_ptr<Graph>* graph) override {
|
||||||
|
return underlying_->MaybeRewriteGraph(library, graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status FillContextMap(const Graph* graph,
|
||||||
|
DeviceContextMap* device_context_map) override {
|
||||||
|
return underlying_->FillContextMap(graph, device_context_map);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
|
||||||
|
bool owns_underlying);
|
||||||
|
Device* const underlying_;
|
||||||
|
const bool owns_underlying_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
|
@ -66,7 +66,7 @@ class DummyOp : public OpKernel {
|
|||||||
class FakeDevice : public Device {
|
class FakeDevice : public Device {
|
||||||
private:
|
private:
|
||||||
explicit FakeDevice(const DeviceAttributes& device_attributes)
|
explicit FakeDevice(const DeviceAttributes& device_attributes)
|
||||||
: Device(nullptr, device_attributes, nullptr) {}
|
: Device(nullptr, device_attributes) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
|
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
|
||||||
|
@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
|
|||||||
const string& name, Bytes memory_limit,
|
const string& name, Bytes memory_limit,
|
||||||
const DeviceLocality& locality,
|
const DeviceLocality& locality,
|
||||||
Allocator* allocator)
|
Allocator* allocator)
|
||||||
: LocalDevice(options,
|
: LocalDevice(options, Device::BuildDeviceAttributes(
|
||||||
Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit,
|
name, DEVICE_CPU, memory_limit, locality)),
|
||||||
locality),
|
|
||||||
allocator),
|
|
||||||
allocator_(allocator) {}
|
allocator_(allocator) {}
|
||||||
|
|
||||||
ThreadPoolDevice::~ThreadPoolDevice() {}
|
ThreadPoolDevice::~ThreadPoolDevice() {}
|
||||||
|
@ -77,7 +77,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":graph_mgr",
|
":graph_mgr",
|
||||||
":rendezvous_mgr_interface",
|
|
||||||
":worker_cache",
|
":worker_cache",
|
||||||
"//tensorflow/core:master_proto_cc",
|
"//tensorflow/core:master_proto_cc",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -92,9 +91,9 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":graph_mgr",
|
":graph_mgr",
|
||||||
":worker_session",
|
":worker_session",
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -237,6 +236,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:master_proto_cc",
|
"//tensorflow/core:master_proto_cc",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -35,9 +35,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env,
|
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
|
||||||
const string& worker_name)
|
: worker_env_(worker_env) {}
|
||||||
: worker_env_(worker_env), worker_name_(worker_name) {}
|
|
||||||
|
|
||||||
BaseRendezvousMgr::~BaseRendezvousMgr() {
|
BaseRendezvousMgr::~BaseRendezvousMgr() {
|
||||||
for (auto& p : table_) {
|
for (auto& p : table_) {
|
||||||
@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Rendezvous* BaseRendezvousMgr::Find(int64 step_id) {
|
RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
|
||||||
return FindOrCreate(step_id);
|
return FindOrCreate(step_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
|
|||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
Table::iterator iter = table_.find(step_id);
|
Table::iterator iter = table_.find(step_id);
|
||||||
if (iter == table_.end()) {
|
if (iter == table_.end()) {
|
||||||
auto rr = Create(step_id, worker_env_, worker_name_);
|
auto rr = Create(step_id, worker_env_);
|
||||||
iter = table_.insert({step_id, rr}).first;
|
iter = table_.insert({step_id, rr}).first;
|
||||||
}
|
}
|
||||||
iter->second->Ref();
|
iter->second->Ref();
|
||||||
@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
|
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||||
const string& worker_name,
|
|
||||||
int64 step_id,
|
|
||||||
bool tolerate_dup_recv)
|
bool tolerate_dup_recv)
|
||||||
: env_(env),
|
: env_(env),
|
||||||
worker_name_(worker_name),
|
|
||||||
step_id_(step_id),
|
step_id_(step_id),
|
||||||
local_(NewLocalRendezvous(tolerate_dup_recv)) {}
|
local_(NewLocalRendezvous(tolerate_dup_recv)),
|
||||||
|
session_(nullptr) {}
|
||||||
|
|
||||||
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
|
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
|
||||||
CHECK(active_.empty());
|
CHECK(active_.empty());
|
||||||
@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name,
|
|||||||
return device_name.starts_with(worker_name);
|
return device_name.starts_with(worker_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
|
||||||
|
CHECK_NE(session, nullptr) << "session must not be null!";
|
||||||
|
std::vector<DeferredCall> deferred_calls;
|
||||||
|
{
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (session_ != nullptr) {
|
||||||
|
if (session_->worker_name == session->worker_name) {
|
||||||
|
LOG(INFO) << "Skipping rendezvous re-initialization.";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status s = errors::Internal(
|
||||||
|
"Double init! Worker names would have changed from: ",
|
||||||
|
session_->worker_name, " -> ", session->worker_name);
|
||||||
|
LOG(WARNING) << s;
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
session_ = session;
|
||||||
|
std::swap(deferred_calls, deferred_calls_);
|
||||||
|
}
|
||||||
|
for (DeferredCall& call : deferred_calls) {
|
||||||
|
RecvLocalAsyncInternal(call.parsed, std::move(call.done));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
WorkerSession* BaseRemoteRendezvous::session() {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
return session_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BaseRemoteRendezvous::is_initialized() {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
return is_initialized_locked();
|
||||||
|
}
|
||||||
|
|
||||||
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
|
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
|
||||||
const Rendezvous::Args& args,
|
const Rendezvous::Args& args,
|
||||||
const Tensor& val, const bool is_dead) {
|
const Tensor& val, const bool is_dead) {
|
||||||
@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
|
|||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (!status_.ok()) return status_;
|
if (!status_.ok()) return status_;
|
||||||
|
DCHECK(is_initialized_locked());
|
||||||
|
if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
|
||||||
|
session_->worker_name);
|
||||||
}
|
}
|
||||||
if (!IsLocalDevice(worker_name_, parsed.src_device)) {
|
|
||||||
return errors::InvalidArgument("Invalid rendezvous key (src): ",
|
|
||||||
parsed.FullKey(), " @ ", worker_name_);
|
|
||||||
}
|
}
|
||||||
// Buffers "val" and "device_context" in local_.
|
// Buffers "val" and "device_context" in local_.
|
||||||
return local_->Send(parsed, args, val, is_dead);
|
return local_->Send(parsed, args, val, is_dead);
|
||||||
@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
|
|||||||
|
|
||||||
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
|
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
|
||||||
bool is_src) {
|
bool is_src) {
|
||||||
|
// Cache session pointer to avoid repeatedly taking & releasing the lock
|
||||||
|
// (e.g. calling session())
|
||||||
|
WorkerSession* sess = nullptr;
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (!status_.ok()) return status_;
|
if (!status_.ok()) return status_;
|
||||||
|
if (!is_initialized_locked()) {
|
||||||
|
return errors::Internal("ValidateDevices called before initialization.");
|
||||||
}
|
}
|
||||||
if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) {
|
sess = session_;
|
||||||
|
}
|
||||||
|
if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
|
||||||
return errors::InvalidArgument("Invalid rendezvous key (src): ",
|
return errors::InvalidArgument("Invalid rendezvous key (src): ",
|
||||||
parsed.FullKey(), " @ ", worker_name_);
|
parsed.FullKey(), " @ ", sess->worker_name);
|
||||||
}
|
}
|
||||||
if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) {
|
if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
|
||||||
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
|
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
|
||||||
parsed.FullKey(), " @ ", worker_name_);
|
parsed.FullKey(), " @ ", sess->worker_name);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
|
|||||||
const Rendezvous::Args& recv_args,
|
const Rendezvous::Args& recv_args,
|
||||||
DoneCallback done) {
|
DoneCallback done) {
|
||||||
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
|
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
|
||||||
|
CHECK(is_initialized()) << "RecvAsync called when uninitialized.";
|
||||||
Status s = ValidateDevices(parsed, false /*!is_src*/);
|
Status s = ValidateDevices(parsed, false /*!is_src*/);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
done(s, Args(), recv_args, Tensor(), false);
|
done(s, Args(), recv_args, Tensor(), false);
|
||||||
@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
|
|||||||
|
|
||||||
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
|
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
|
||||||
DoneCallback done) {
|
DoneCallback done) {
|
||||||
|
{
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (!is_initialized_locked()) {
|
||||||
|
// RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
|
||||||
|
// remote worker) before the RunStep (or PartialRunStep) RPC from the
|
||||||
|
// master arrives. RecvLocalAsync thus buffers the arguments until after
|
||||||
|
// the RemoteRendezvous is Initialize()'d, when it completes the
|
||||||
|
// rendezvous logic. At some point after Initialize() is called, a Tensor
|
||||||
|
// is produced locally that will then be sent in response to the incoming
|
||||||
|
// RPC.
|
||||||
|
DeferredCall call(parsed, std::move(done));
|
||||||
|
deferred_calls_.push_back(call);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RecvLocalAsyncInternal(parsed, std::move(done));
|
||||||
|
}
|
||||||
|
|
||||||
|
void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
|
||||||
|
DoneCallback done) {
|
||||||
Status s = ValidateDevices(parsed, true /* is_src */);
|
Status s = ValidateDevices(parsed, true /* is_src */);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
done(s, Args(), Args(), Tensor(), false);
|
done(s, Args(), Args(), Tensor(), false);
|
||||||
@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
|
|||||||
active_.erase(call);
|
active_.erase(call);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
|
||||||
|
DoneCallback done)
|
||||||
|
: parsed(parsed), done(std::move(done)) {}
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -59,15 +59,17 @@ class BaseRecvTensorCall;
|
|||||||
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
|
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
|
||||||
class BaseRendezvousMgr : public RendezvousMgrInterface {
|
class BaseRendezvousMgr : public RendezvousMgrInterface {
|
||||||
public:
|
public:
|
||||||
explicit BaseRendezvousMgr(const WorkerEnv* worker_env,
|
explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
|
||||||
const string& worker_name);
|
|
||||||
|
|
||||||
~BaseRendezvousMgr() override;
|
~BaseRendezvousMgr() override;
|
||||||
|
|
||||||
// Returns Rendezvous supporting send and recv among workers in the
|
// Returns Rendezvous supporting send and recv among workers in the
|
||||||
// "step_id". The caller takes ownership of one reference on the
|
// "step_id". The caller takes ownership of one reference on the
|
||||||
// returned Rendezvous instance.
|
// returned Rendezvous instance.
|
||||||
Rendezvous* Find(int64 step_id) override;
|
//
|
||||||
|
// Note: the caller must guarantee to eventually call Initialize on the
|
||||||
|
// returned RemoteRendezvous
|
||||||
|
RemoteRendezvous* Find(int64 step_id) override;
|
||||||
|
|
||||||
// Finds the local rendezvous instance for the "step_id". Runs
|
// Finds the local rendezvous instance for the "step_id". Runs
|
||||||
// "done" when the tensor for "key" is produced or an error occurs.
|
// "done" when the tensor for "key" is produced or an error occurs.
|
||||||
@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual BaseRemoteRendezvous* Create(int64 step_id,
|
virtual BaseRemoteRendezvous* Create(int64 step_id,
|
||||||
const WorkerEnv* worker_env,
|
const WorkerEnv* worker_env) = 0;
|
||||||
const string& worker_name) = 0;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Maps step_id to rendezvous.
|
// Maps step_id to rendezvous.
|
||||||
@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
|
|||||||
|
|
||||||
// Not owned.
|
// Not owned.
|
||||||
const WorkerEnv* const worker_env_;
|
const WorkerEnv* const worker_env_;
|
||||||
const string worker_name_;
|
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
Table table_ GUARDED_BY(mu_);
|
Table table_ GUARDED_BY(mu_);
|
||||||
@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
|
|||||||
// Buffering of Tensor values is delegated to a "local" Rendezvous
|
// Buffering of Tensor values is delegated to a "local" Rendezvous
|
||||||
// obtained from NewLocalRendezvous(). This class just adds
|
// obtained from NewLocalRendezvous(). This class just adds
|
||||||
// functionality to coordinate with remote workers.
|
// functionality to coordinate with remote workers.
|
||||||
class BaseRemoteRendezvous : public Rendezvous {
|
class BaseRemoteRendezvous : public RemoteRendezvous {
|
||||||
public:
|
public:
|
||||||
BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
|
BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||||
int64 step_id, bool tolerate_dup_recv);
|
bool tolerate_dup_recv);
|
||||||
|
|
||||||
|
// Upgrades the BaseRemoteRendezvous to full initialization.
|
||||||
|
Status Initialize(WorkerSession* session) override;
|
||||||
|
|
||||||
// Forwards to local_, where the Tensor "val" will be buffered and
|
// Forwards to local_, where the Tensor "val" will be buffered and
|
||||||
// any waiting callback stored.
|
// any waiting callback stored.
|
||||||
@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous {
|
|||||||
// Removes "call" from active_ if "call" is in active_.
|
// Removes "call" from active_ if "call" is in active_.
|
||||||
void DeregisterCall(BaseRecvTensorCall* call);
|
void DeregisterCall(BaseRecvTensorCall* call);
|
||||||
|
|
||||||
|
WorkerSession* session();
|
||||||
|
|
||||||
|
bool is_initialized();
|
||||||
|
|
||||||
~BaseRemoteRendezvous() override;
|
~BaseRemoteRendezvous() override;
|
||||||
|
|
||||||
const WorkerEnv* const env_; // Not owned.
|
const WorkerEnv* const env_; // Not owned.
|
||||||
const string worker_name_;
|
|
||||||
const int64 step_id_;
|
const int64 step_id_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous {
|
|||||||
|
|
||||||
// Status given by StartAbort() if any.
|
// Status given by StartAbort() if any.
|
||||||
Status status_ GUARDED_BY(mu_);
|
Status status_ GUARDED_BY(mu_);
|
||||||
|
WorkerSession* session_ GUARDED_BY(mu_); // Not owned.
|
||||||
|
|
||||||
|
// Data structures to handle calls when partially initialized.
|
||||||
|
struct DeferredCall {
|
||||||
|
const ParsedKey parsed;
|
||||||
|
DoneCallback done;
|
||||||
|
|
||||||
|
DeferredCall(const ParsedKey& parsed, DoneCallback done);
|
||||||
|
};
|
||||||
|
std::vector<DeferredCall> deferred_calls_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
// Active outstanding RecvTensor calls.
|
// Active outstanding RecvTensor calls.
|
||||||
gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
|
gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
|
return session_ != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
// If "is_src" is true, checks that the rendezvous key "parsed"'s
|
// If "is_src" is true, checks that the rendezvous key "parsed"'s
|
||||||
// source is in this process. If "is_src" is false, checks that the
|
// source is in this process. If "is_src" is false, checks that the
|
||||||
// rendezvous key "parsed"'s destination is in this process.
|
// rendezvous key "parsed"'s destination is in this process.
|
||||||
@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous {
|
|||||||
const Rendezvous::Args& out_args, const Tensor& in,
|
const Rendezvous::Args& out_args, const Tensor& in,
|
||||||
Tensor* out, StatusCallback done);
|
Tensor* out, StatusCallback done);
|
||||||
|
|
||||||
|
// Must be called only if fully initialized.
|
||||||
|
void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
|
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -46,10 +46,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
GraphMgr::GraphMgr(const WorkerEnv* worker_env,
|
GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
|
||||||
RendezvousMgrInterface* rendezvous_mgr)
|
: worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
|
||||||
: worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) {
|
|
||||||
CHECK(rendezvous_mgr) << "Rendezvous mgr was null";
|
|
||||||
// The default value of sync_on_finish will be flipped soon and this
|
// The default value of sync_on_finish will be flipped soon and this
|
||||||
// environment variable will be removed as well.
|
// environment variable will be removed as well.
|
||||||
Status status =
|
Status status =
|
||||||
@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
|||||||
};
|
};
|
||||||
popts.get_incarnation = [this](const string& name) -> int64 {
|
popts.get_incarnation = [this](const string& name) -> int64 {
|
||||||
Device* device = nullptr;
|
Device* device = nullptr;
|
||||||
Status s = worker_env_->device_mgr->LookupDevice(name, &device);
|
Status s = device_mgr_->LookupDevice(name, &device);
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
return device->attributes().incarnation();
|
return device->attributes().incarnation();
|
||||||
} else {
|
} else {
|
||||||
@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
|||||||
ExecutionUnit* unit = &(item->units.back());
|
ExecutionUnit* unit = &(item->units.back());
|
||||||
|
|
||||||
// Find the device.
|
// Find the device.
|
||||||
Status s =
|
Status s = device_mgr_->LookupDevice(device_name, &unit->device);
|
||||||
worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
// Remove the empty unit from the item as the item destructor wants all
|
// Remove the empty unit from the item as the item destructor wants all
|
||||||
// units to have valid devices.
|
// units to have valid devices.
|
||||||
@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
|||||||
|
|
||||||
// Function library runtime.
|
// Function library runtime.
|
||||||
unit->lib = NewFunctionLibraryRuntime(
|
unit->lib = NewFunctionLibraryRuntime(
|
||||||
worker_env_->device_mgr, worker_env_->env, unit->device,
|
device_mgr_, worker_env_->env, unit->device,
|
||||||
subgraph->versions().producer(), item->lib_def,
|
subgraph->versions().producer(), item->lib_def,
|
||||||
graph_options.optimizer_options());
|
graph_options.optimizer_options());
|
||||||
|
|
||||||
@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
|
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
|
||||||
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
Status s = SendInputsToRendezvous(rendezvous, in);
|
Status s = SendInputsToRendezvous(rendezvous, in);
|
||||||
rendezvous->Unref();
|
rendezvous->Unref();
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
||||||
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
Status s = RecvOutputsFromRendezvous(rendezvous, out);
|
Status s = RecvOutputsFromRendezvous(rendezvous, out);
|
||||||
rendezvous->Unref();
|
rendezvous->Unref();
|
||||||
return s;
|
return s;
|
||||||
@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
|||||||
|
|
||||||
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
|
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
|
||||||
StatusCallback done) {
|
StatusCallback done) {
|
||||||
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
RecvOutputsFromRendezvousAsync(rendezvous, out,
|
RecvOutputsFromRendezvousAsync(rendezvous, out,
|
||||||
[done, rendezvous](const Status s) {
|
[done, rendezvous](const Status s) {
|
||||||
rendezvous->Unref();
|
rendezvous->Unref();
|
||||||
@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||||
const ExecutorOpts& opts,
|
WorkerSession* session,
|
||||||
|
const ExecutorOpts& /*opts*/,
|
||||||
StepStatsCollector* collector,
|
StepStatsCollector* collector,
|
||||||
CostGraphDef* cost_graph,
|
CostGraphDef* cost_graph,
|
||||||
CancellationManager* cancellation_manager,
|
CancellationManager* cancellation_manager,
|
||||||
@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
|
RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
|
Status s = rendezvous->Initialize(session);
|
||||||
|
|
||||||
// Sends values specified by the caller.
|
// Sends values specified by the caller.
|
||||||
Status s = SendInputsToRendezvous(rendezvous, in);
|
if (s.ok()) {
|
||||||
|
s = SendInputsToRendezvous(rendezvous, in);
|
||||||
|
}
|
||||||
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
done(s);
|
done(s);
|
||||||
item->Unref();
|
item->Unref();
|
||||||
@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
|
|||||||
StatusCallback done) {
|
StatusCallback done) {
|
||||||
const int num_units = item->units.size();
|
const int num_units = item->units.size();
|
||||||
CHECK_GE(num_units, 1);
|
CHECK_GE(num_units, 1);
|
||||||
ScopedStepContainer* step_container =
|
ScopedStepContainer* step_container = new ScopedStepContainer(
|
||||||
new ScopedStepContainer(step_id, [this](const string& name) {
|
step_id,
|
||||||
worker_env_->device_mgr->ClearContainers({name});
|
[this](const string& name) { device_mgr_->ClearContainers({name}); });
|
||||||
});
|
|
||||||
// NOTE: Transfer one ref of rendezvous and item.
|
// NOTE: Transfer one ref of rendezvous and item.
|
||||||
ExecutorBarrier* barrier =
|
ExecutorBarrier* barrier =
|
||||||
new ExecutorBarrier(num_units, rendezvous,
|
new ExecutorBarrier(num_units, rendezvous,
|
||||||
|
@ -37,6 +37,8 @@ namespace tensorflow {
|
|||||||
class ExecutorOpts;
|
class ExecutorOpts;
|
||||||
class StepStatsCollector;
|
class StepStatsCollector;
|
||||||
class RendezvousMgrInterface;
|
class RendezvousMgrInterface;
|
||||||
|
class DeviceMgr;
|
||||||
|
struct WorkerSession;
|
||||||
|
|
||||||
// GraphMgr keeps track of a set of graphs that are registered with a
|
// GraphMgr keeps track of a set of graphs that are registered with a
|
||||||
// TensorFlow worker. Each registered graph is identified by a handle
|
// TensorFlow worker. Each registered graph is identified by a handle
|
||||||
@ -62,8 +64,7 @@ class RendezvousMgrInterface;
|
|||||||
// EXPECT_EQ(out["c"], Tensor({4, 6}));
|
// EXPECT_EQ(out["c"], Tensor({4, 6}));
|
||||||
class GraphMgr {
|
class GraphMgr {
|
||||||
public:
|
public:
|
||||||
explicit GraphMgr(const WorkerEnv* worker_env,
|
explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
|
||||||
RendezvousMgrInterface* rendezvous_mgr);
|
|
||||||
~GraphMgr();
|
~GraphMgr();
|
||||||
|
|
||||||
// Registers a graph. Fills in "handle"
|
// Registers a graph. Fills in "handle"
|
||||||
@ -78,8 +79,8 @@ class GraphMgr {
|
|||||||
typedef std::map<string, Tensor> NamedTensors;
|
typedef std::map<string, Tensor> NamedTensors;
|
||||||
typedef std::function<void(const Status&)> StatusCallback;
|
typedef std::function<void(const Status&)> StatusCallback;
|
||||||
void ExecuteAsync(const string& handle, const int64 step_id,
|
void ExecuteAsync(const string& handle, const int64 step_id,
|
||||||
const ExecutorOpts& opts, StepStatsCollector* collector,
|
WorkerSession* session, const ExecutorOpts& opts,
|
||||||
CostGraphDef* cost_graph,
|
StepStatsCollector* collector, CostGraphDef* cost_graph,
|
||||||
CancellationManager* cancellation_manager,
|
CancellationManager* cancellation_manager,
|
||||||
const NamedTensors& in, StatusCallback done);
|
const NamedTensors& in, StatusCallback done);
|
||||||
|
|
||||||
@ -131,7 +132,7 @@ class GraphMgr {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const WorkerEnv* worker_env_; // Not owned.
|
const WorkerEnv* worker_env_; // Not owned.
|
||||||
RendezvousMgrInterface* rendezvous_mgr_; // Not owned.
|
DeviceMgr* device_mgr_;
|
||||||
|
|
||||||
CostModelManager cost_model_manager_;
|
CostModelManager cost_model_manager_;
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
#include "tensorflow/core/common_runtime/process_util.h"
|
#include "tensorflow/core/common_runtime/process_util.h"
|
||||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||||
@ -48,12 +49,17 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||||
#include "tensorflow/core/protobuf/master.pb.h"
|
#include "tensorflow/core/protobuf/master.pb.h"
|
||||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||||
#include "tensorflow/core/public/session_options.h"
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
const char* const kGrpcProtocol = "grpc://";
|
||||||
|
} // namespace
|
||||||
|
|
||||||
Master::Master(MasterEnv* env, double session_gc_seconds)
|
Master::Master(MasterEnv* env, double session_gc_seconds)
|
||||||
: env_(env),
|
: env_(env),
|
||||||
last_1000_steps_(1000),
|
last_1000_steps_(1000),
|
||||||
@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req,
|
|||||||
CreateSessionResponse* resp, MyClosure done) {
|
CreateSessionResponse* resp, MyClosure done) {
|
||||||
SchedClosure([this, req, resp, done]() {
|
SchedClosure([this, req, resp, done]() {
|
||||||
Status status;
|
Status status;
|
||||||
|
WorkerCacheFactoryOptions worker_cache_factory_options;
|
||||||
|
string grpc_protocol("grpc");
|
||||||
|
worker_cache_factory_options.protocol = &grpc_protocol;
|
||||||
auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
|
auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
|
||||||
status = ValidateExternalGraphDefSyntax(req->graph_def());
|
status = ValidateExternalGraphDefSyntax(req->graph_def());
|
||||||
if (!status.ok()) return;
|
if (!status.ok()) return;
|
||||||
// Ping all the workers and build the list of devices that the
|
|
||||||
// session will use.
|
// The following 4 variables are set differently, depending on whether this
|
||||||
|
// session uses a client-provided clusterspec or not.
|
||||||
|
WorkerCacheInterface* worker_cache = nullptr;
|
||||||
|
// Note: worker_cache_ptr will be null except if this session is using a
|
||||||
|
// client-supplied ClusterDef (ClusterSpec propagation).
|
||||||
|
std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
|
||||||
|
std::unique_ptr<DeviceSet> device_set;
|
||||||
// TODO(saeta): Convert to std::make_unique when available.
|
// TODO(saeta): Convert to std::make_unique when available.
|
||||||
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
|
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
|
||||||
new std::vector<std::unique_ptr<Device>>());
|
new std::vector<std::unique_ptr<Device>>());
|
||||||
status = DeviceFinder::GetRemoteDevices(req->config().device_filters(),
|
|
||||||
env_, env_->worker_cache,
|
if (req->config().has_cluster_def()) {
|
||||||
remote_devices.get());
|
worker_cache_factory_options.cluster_def = &req->config().cluster_def();
|
||||||
|
|
||||||
|
// Set the server_def's job_name and task_index fields.
|
||||||
|
string normalized_string;
|
||||||
|
string grpc_protocol(kGrpcProtocol);
|
||||||
|
if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
|
||||||
|
0) {
|
||||||
|
normalized_string =
|
||||||
|
req->target().substr(grpc_protocol.length(), string::npos);
|
||||||
|
} else {
|
||||||
|
normalized_string = req->target();
|
||||||
|
}
|
||||||
|
for (auto&& job : req->config().cluster_def().job()) {
|
||||||
|
for (auto&& task : job.tasks()) {
|
||||||
|
if (task.second == normalized_string) {
|
||||||
|
if (worker_cache_factory_options.job_name != nullptr) {
|
||||||
|
status = errors::InvalidArgument(
|
||||||
|
"Found multiple matching tasks that correspond to "
|
||||||
|
"to the master. Master target: '",
|
||||||
|
req->target(), "'. ClusterDef: ",
|
||||||
|
req->config().cluster_def().ShortDebugString());
|
||||||
|
LOG(ERROR) << status;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (env_->local_devices[0]->parsed_name().job == job.name() &&
|
||||||
|
env_->local_devices[0]->parsed_name().task == task.first) {
|
||||||
|
// TODO(b/37868888): Remove this limitation when resolved
|
||||||
|
status = errors::InvalidArgument(
|
||||||
|
"The ClusterSpec names the job and task index to be the same "
|
||||||
|
"names that were provided when the server booted. This is "
|
||||||
|
"currently not allowed. Job: ",
|
||||||
|
job.name(), ", task index: ", task.first);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
worker_cache_factory_options.job_name = &job.name();
|
||||||
|
worker_cache_factory_options.task_index = task.first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the worker cache from the computed server_def.
|
||||||
|
status = env_->worker_cache_factory(worker_cache_factory_options,
|
||||||
|
&worker_cache);
|
||||||
if (!status.ok()) return;
|
if (!status.ok()) return;
|
||||||
|
worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
|
||||||
|
// Ping all the workers and build the list of devices that the
|
||||||
|
// session will use.
|
||||||
|
status =
|
||||||
|
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
|
||||||
|
worker_cache, remote_devices.get());
|
||||||
|
if (!status.ok()) return;
|
||||||
|
device_set.reset(new DeviceSet);
|
||||||
|
for (auto&& d : *remote_devices) {
|
||||||
|
device_set->AddDevice(d.get());
|
||||||
|
DeviceNameUtils::ParsedName name = d->parsed_name();
|
||||||
|
if (name.job == *worker_cache_factory_options.job_name &&
|
||||||
|
name.task == worker_cache_factory_options.task_index &&
|
||||||
|
name.type == "CPU") {
|
||||||
|
device_set->set_client_device(d.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
worker_cache = env_->worker_cache;
|
||||||
|
// Ping all the workers and build the list of devices that the
|
||||||
|
// session will use.
|
||||||
|
status =
|
||||||
|
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
|
||||||
|
worker_cache, remote_devices.get());
|
||||||
|
if (!status.ok()) return;
|
||||||
|
device_set.reset(new DeviceSet);
|
||||||
|
for (auto&& d : *remote_devices) {
|
||||||
|
device_set->AddDevice(d.get());
|
||||||
|
}
|
||||||
|
int num_local_devices = 0;
|
||||||
|
for (Device* d : env_->local_devices) {
|
||||||
|
device_set->AddDevice(d);
|
||||||
|
if (num_local_devices == 0) {
|
||||||
|
// Uses the first local device as the client device.
|
||||||
|
device_set->set_client_device(d);
|
||||||
|
}
|
||||||
|
num_local_devices++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK(device_set->client_device());
|
||||||
|
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
options.config = req->config();
|
options.config = req->config();
|
||||||
MasterSession* session =
|
|
||||||
env_->master_session_factory(options, env_, std::move(remote_devices));
|
MasterSession* session = env_->master_session_factory(
|
||||||
|
options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
|
||||||
|
std::move(device_set));
|
||||||
|
|
||||||
GraphDef* gdef =
|
GraphDef* gdef =
|
||||||
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
|
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
|
||||||
status = session->Create(gdef);
|
|
||||||
|
status = session->Create(gdef, worker_cache_factory_options);
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
session->Close().IgnoreError();
|
session->Close().IgnoreError();
|
||||||
session->Unref();
|
session->Unref();
|
||||||
|
@ -19,17 +19,41 @@ limitations under the License.
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/master_session.h"
|
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||||
#include "tensorflow/core/public/session_options.h"
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class Device;
|
class Device;
|
||||||
|
class DeviceSet;
|
||||||
class Env;
|
class Env;
|
||||||
class MasterSession;
|
class MasterSession;
|
||||||
class OpRegistryInterface;
|
class OpRegistryInterface;
|
||||||
class WorkerCacheInterface;
|
class WorkerCacheInterface;
|
||||||
|
|
||||||
|
// Options passed to the worker_cache_factory function.
|
||||||
|
struct WorkerCacheFactoryOptions {
|
||||||
|
const ClusterDef* cluster_def = nullptr;
|
||||||
|
const string* job_name = nullptr;
|
||||||
|
int task_index;
|
||||||
|
const string* protocol = nullptr;
|
||||||
|
|
||||||
|
WorkerCacheFactoryOptions() {}
|
||||||
|
|
||||||
|
// Construct from a ServerDef proto.
|
||||||
|
//
|
||||||
|
// Note: server_def must outlive WorkerCacheFactoryOptions!
|
||||||
|
WorkerCacheFactoryOptions(const ServerDef& server_def) {
|
||||||
|
if (server_def.has_cluster() && !server_def.job_name().empty()) {
|
||||||
|
cluster_def = &server_def.cluster();
|
||||||
|
job_name = &server_def.job_name();
|
||||||
|
task_index = server_def.task_index();
|
||||||
|
protocol = &server_def.protocol();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// The master environment class, which holds a bag of pointers to
|
// The master environment class, which holds a bag of pointers to
|
||||||
// per-master state.
|
// per-master state.
|
||||||
//
|
//
|
||||||
@ -57,8 +81,14 @@ struct MasterEnv {
|
|||||||
// `MasterEnv*` is retained by the caller.
|
// `MasterEnv*` is retained by the caller.
|
||||||
std::function<MasterSession*(
|
std::function<MasterSession*(
|
||||||
SessionOptions, MasterEnv*,
|
SessionOptions, MasterEnv*,
|
||||||
std::unique_ptr<std::vector<std::unique_ptr<Device>>>)>
|
std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
|
||||||
|
std::unique_ptr<WorkerCacheInterface>,
|
||||||
|
std::unique_ptr<DeviceSet> device_set)>
|
||||||
master_session_factory;
|
master_session_factory;
|
||||||
|
|
||||||
|
std::function<Status(const WorkerCacheFactoryOptions&,
|
||||||
|
WorkerCacheInterface**)>
|
||||||
|
worker_cache_factory;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -36,11 +36,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/notification.h"
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
#include "tensorflow/core/lib/core/refcount.h"
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
#include "tensorflow/core/lib/strings/numbers.h"
|
#include "tensorflow/core/lib/strings/numbers.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -162,7 +164,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
|||||||
// Partitions the graph into subgraphs and registers them on
|
// Partitions the graph into subgraphs and registers them on
|
||||||
// workers.
|
// workers.
|
||||||
Status RegisterPartitions(const PartitionOptions& popts,
|
Status RegisterPartitions(const PartitionOptions& popts,
|
||||||
const FunctionDefLibrary& func_def_lib);
|
const FunctionLibraryDefinition& flib_def);
|
||||||
|
|
||||||
// Runs one step of all partitions.
|
// Runs one step of all partitions.
|
||||||
Status RunPartitions(const MasterEnv* env, int64 step_id,
|
Status RunPartitions(const MasterEnv* env, int64 step_id,
|
||||||
@ -273,7 +275,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Status MasterSession::ReffedClientGraph::RegisterPartitions(
|
Status MasterSession::ReffedClientGraph::RegisterPartitions(
|
||||||
const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib) {
|
const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) {
|
||||||
{ // Ensure register once.
|
{ // Ensure register once.
|
||||||
mu_.lock();
|
mu_.lock();
|
||||||
if (!init_started_) {
|
if (!init_started_) {
|
||||||
@ -292,7 +294,8 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions(
|
|||||||
graph_defs_for_publishing.push_back(&name_def.second);
|
graph_defs_for_publishing.push_back(&name_def.second);
|
||||||
}
|
}
|
||||||
stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
|
stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
|
||||||
s = DoRegisterPartitions(popts, func_def_lib, std::move(graph_defs));
|
s = DoRegisterPartitions(popts, flib_def.ToProto(),
|
||||||
|
std::move(graph_defs));
|
||||||
}
|
}
|
||||||
mu_.lock();
|
mu_.lock();
|
||||||
init_result_ = s;
|
init_result_ = s;
|
||||||
@ -527,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
|||||||
c->req->set_is_partial(is_partial_);
|
c->req->set_is_partial(is_partial_);
|
||||||
c->req->set_is_last_partial_run(is_last_partial_run);
|
c->req->set_is_last_partial_run(is_last_partial_run);
|
||||||
}
|
}
|
||||||
|
c->req->set_session_handle(session_handle_);
|
||||||
c->req->set_graph_handle(part.graph_handle);
|
c->req->set_graph_handle(part.graph_handle);
|
||||||
c->req->set_step_id(step_id);
|
c->req->set_step_id(step_id);
|
||||||
*c->req->mutable_exec_opts() = exec_opts;
|
*c->req->mutable_exec_opts() = exec_opts;
|
||||||
@ -870,6 +874,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
|
|||||||
// The graph handle may be empty if we failed during partition registration.
|
// The graph handle may be empty if we failed during partition registration.
|
||||||
if (!part.graph_handle.empty()) {
|
if (!part.graph_handle.empty()) {
|
||||||
Call* c = new Call;
|
Call* c = new Call;
|
||||||
|
c->req.set_session_handle(session_handle_);
|
||||||
c->req.set_graph_handle(part.graph_handle);
|
c->req.set_graph_handle(part.graph_handle);
|
||||||
// NOTE(mrry): We must capture `worker_cache_` since `this`
|
// NOTE(mrry): We must capture `worker_cache_` since `this`
|
||||||
// could be deleted before the callback is called.
|
// could be deleted before the callback is called.
|
||||||
@ -972,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
|
|||||||
MasterSession::MasterSession(
|
MasterSession::MasterSession(
|
||||||
const SessionOptions& opt, const MasterEnv* env,
|
const SessionOptions& opt, const MasterEnv* env,
|
||||||
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
|
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
|
||||||
|
std::unique_ptr<WorkerCacheInterface> worker_cache,
|
||||||
|
std::unique_ptr<DeviceSet> device_set,
|
||||||
StatsPublisherFactory stats_publisher_factory)
|
StatsPublisherFactory stats_publisher_factory)
|
||||||
: session_opts_(opt),
|
: session_opts_(opt),
|
||||||
env_(env),
|
env_(env),
|
||||||
handle_(strings::FpToString(random::New64())),
|
handle_(strings::FpToString(random::New64())),
|
||||||
remote_devs_(std::move(remote_devs)),
|
remote_devs_(std::move(remote_devs)),
|
||||||
|
worker_cache_(std::move(worker_cache)),
|
||||||
|
devices_(std::move(device_set)),
|
||||||
stats_publisher_factory_(std::move(stats_publisher_factory)),
|
stats_publisher_factory_(std::move(stats_publisher_factory)),
|
||||||
graph_version_(0),
|
graph_version_(0),
|
||||||
run_graphs_(5),
|
run_graphs_(5),
|
||||||
partial_run_graphs_(5) {
|
partial_run_graphs_(5) {
|
||||||
UpdateLastAccessTime();
|
UpdateLastAccessTime();
|
||||||
|
CHECK(devices_) << "device_set was null!";
|
||||||
|
|
||||||
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
|
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
|
||||||
<< " #remote " << remote_devs_->size();
|
<< " #remote " << remote_devs_->size();
|
||||||
for (auto&& d : *remote_devs_) {
|
|
||||||
devices_.AddDevice(d.get());
|
|
||||||
}
|
|
||||||
int num_local_devices = 0;
|
|
||||||
for (Device* d : env->local_devices) {
|
|
||||||
devices_.AddDevice(d);
|
|
||||||
if (num_local_devices == 0) {
|
|
||||||
// Uses the first local device as the client device.
|
|
||||||
devices_.set_client_device(d);
|
|
||||||
}
|
|
||||||
num_local_devices++;
|
|
||||||
}
|
|
||||||
LOG(INFO) << "Start master session " << handle_
|
LOG(INFO) << "Start master session " << handle_
|
||||||
<< " with config: " << std::endl
|
<< " with config: " << std::endl
|
||||||
<< session_opts_.config.DebugString();
|
<< session_opts_.config.DebugString();
|
||||||
@ -1011,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() {
|
|||||||
last_access_time_usec_.store(Env::Default()->NowMicros());
|
last_access_time_usec_.store(Env::Default()->NowMicros());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MasterSession::Create(GraphDef* graph_def) {
|
Status MasterSession::Create(GraphDef* graph_def,
|
||||||
|
const WorkerCacheFactoryOptions& options) {
|
||||||
if (session_opts_.config.graph_options().place_pruned_graph()) {
|
if (session_opts_.config.graph_options().place_pruned_graph()) {
|
||||||
// TODO(b/29900832): Fix this or remove the option.
|
// TODO(b/29900832): Fix this or remove the option.
|
||||||
LOG(WARNING) << "Distributed session does not support the "
|
LOG(WARNING) << "Distributed session does not support the "
|
||||||
@ -1019,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) {
|
|||||||
session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
|
session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
SimpleGraphExecutionStateOptions options;
|
SimpleGraphExecutionStateOptions execution_options;
|
||||||
options.device_set = &devices_;
|
execution_options.device_set = devices_.get();
|
||||||
options.session_options = &session_opts_;
|
execution_options.session_options = &session_opts_;
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
|
TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
|
||||||
graph_def, options, &execution_state_));
|
graph_def, execution_options, &execution_state_));
|
||||||
|
}
|
||||||
|
if (options.cluster_def != nullptr) {
|
||||||
|
return CreateWorkerSessions(options);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status MasterSession::CreateWorkerSessions(
|
||||||
|
const WorkerCacheFactoryOptions& options) {
|
||||||
|
CHECK(worker_cache_) << "CreateWorkerSessions should be called only with "
|
||||||
|
<< "dynamic cluster membership.";
|
||||||
|
std::vector<string> worker_names;
|
||||||
|
worker_cache_->ListWorkers(&worker_names);
|
||||||
|
|
||||||
|
struct WorkerGroup {
|
||||||
|
// The worker name. (Not owned.)
|
||||||
|
const string* name;
|
||||||
|
|
||||||
|
// The worker referenced by name. (Not owned.)
|
||||||
|
WorkerInterface* worker = nullptr;
|
||||||
|
|
||||||
|
// Request and responses used for a given worker.
|
||||||
|
CreateWorkerSessionRequest request;
|
||||||
|
CreateWorkerSessionResponse response;
|
||||||
|
Status status = Status::OK();
|
||||||
|
};
|
||||||
|
BlockingCounter done(worker_names.size());
|
||||||
|
std::vector<WorkerGroup> workers(worker_names.size());
|
||||||
|
|
||||||
|
// Release the workers.
|
||||||
|
auto cleanup = gtl::MakeCleanup([this, &workers] {
|
||||||
|
for (auto&& worker_group : workers) {
|
||||||
|
if (worker_group.worker != nullptr) {
|
||||||
|
worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Status status = Status::OK();
|
||||||
|
// Create all the workers & kick off the computations.
|
||||||
|
for (size_t i = 0; i < worker_names.size(); ++i) {
|
||||||
|
workers[i].name = &worker_names[i];
|
||||||
|
workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
|
||||||
|
workers[i].request.set_session_handle(handle_);
|
||||||
|
*workers[i].request.mutable_server_def()->mutable_cluster() =
|
||||||
|
*options.cluster_def;
|
||||||
|
workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
|
||||||
|
|
||||||
|
DeviceNameUtils::ParsedName name;
|
||||||
|
if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
|
||||||
|
status = errors::Internal("Could not parse name ", worker_names[i]);
|
||||||
|
LOG(WARNING) << status;
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
if (!name.has_job || !name.has_task) {
|
||||||
|
status = errors::Internal("Incomplete worker name ", worker_names[i]);
|
||||||
|
LOG(WARNING) << status;
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
workers[i].request.mutable_server_def()->set_job_name(name.job);
|
||||||
|
workers[i].request.mutable_server_def()->set_task_index(name.task);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < worker_names.size(); ++i) {
|
||||||
|
auto cb = [i, &workers, &done](const Status& s) {
|
||||||
|
workers[i].status = s;
|
||||||
|
done.DecrementCount();
|
||||||
|
};
|
||||||
|
workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
|
||||||
|
&workers[i].response, cb);
|
||||||
|
}
|
||||||
|
|
||||||
|
done.Wait();
|
||||||
|
for (size_t i = 0; i < workers.size(); ++i) {
|
||||||
|
status.Update(workers[i].status);
|
||||||
|
}
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
Status MasterSession::Extend(const ExtendSessionRequest* req,
|
Status MasterSession::Extend(const ExtendSessionRequest* req,
|
||||||
ExtendSessionResponse* resp) {
|
ExtendSessionResponse* resp) {
|
||||||
UpdateLastAccessTime();
|
UpdateLastAccessTime();
|
||||||
@ -1059,6 +1135,13 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
WorkerCacheInterface* MasterSession::get_worker_cache() const {
|
||||||
|
if (worker_cache_) {
|
||||||
|
return worker_cache_.get();
|
||||||
|
}
|
||||||
|
return env_->worker_cache;
|
||||||
|
}
|
||||||
|
|
||||||
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
|
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
|
||||||
ReffedClientGraph** rcg, bool is_partial) {
|
ReffedClientGraph** rcg, bool is_partial) {
|
||||||
const uint64 hash = HashBuildGraphOptions(opts);
|
const uint64 hash = HashBuildGraphOptions(opts);
|
||||||
@ -1082,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
|
|||||||
<< "\n";
|
<< "\n";
|
||||||
std::unique_ptr<SimpleClientGraph> client_graph;
|
std::unique_ptr<SimpleClientGraph> client_graph;
|
||||||
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
|
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
|
||||||
|
WorkerCacheInterface* worker_cache = get_worker_cache();
|
||||||
auto entry = new ReffedClientGraph(
|
auto entry = new ReffedClientGraph(
|
||||||
handle_, opts, std::move(client_graph), session_opts_,
|
handle_, opts, std::move(client_graph), session_opts_,
|
||||||
stats_publisher_factory_, execution_state_.get(), is_partial,
|
stats_publisher_factory_, execution_state_.get(), is_partial,
|
||||||
env_->worker_cache);
|
worker_cache);
|
||||||
|
|
||||||
iter = m->insert({hash, entry}).first;
|
iter = m->insert({hash, entry}).first;
|
||||||
VLOG(1) << "Preparing to execute new graph";
|
VLOG(1) << "Preparing to execute new graph";
|
||||||
}
|
}
|
||||||
@ -1161,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
|
|||||||
return errors::FailedPrecondition("Session is closed.");
|
return errors::FailedPrecondition("Session is closed.");
|
||||||
}
|
}
|
||||||
++num_running_;
|
++num_running_;
|
||||||
|
// Note: all code paths must eventually call MarkRunCompletion()
|
||||||
|
// in order to appropriate decrement the num_running_ counter.
|
||||||
}
|
}
|
||||||
Status status;
|
Status status;
|
||||||
if (!req.partial_run_handle().empty()) {
|
if (!req.partial_run_handle().empty()) {
|
||||||
@ -1168,15 +1253,17 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
|
|||||||
} else {
|
} else {
|
||||||
status = DoRunWithLocalExecution(opts, req, resp);
|
status = DoRunWithLocalExecution(opts, req, resp);
|
||||||
}
|
}
|
||||||
{
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrements num_running_ and broadcasts if num_running_ is zero.
|
||||||
|
void MasterSession::MarkRunCompletion() {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
--num_running_;
|
--num_running_;
|
||||||
if (num_running_ == 0) {
|
if (num_running_ == 0) {
|
||||||
num_running_is_zero_.notify_all();
|
num_running_is_zero_.notify_all();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
||||||
// Registers subgraphs if haven't done so.
|
// Registers subgraphs if haven't done so.
|
||||||
@ -1187,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
|||||||
return strings::StrCat(prefix, "_S", next_node_id_++);
|
return strings::StrCat(prefix, "_S", next_node_id_++);
|
||||||
};
|
};
|
||||||
popts.get_incarnation = [this](const string& name) -> int64 {
|
popts.get_incarnation = [this](const string& name) -> int64 {
|
||||||
Device* d = devices_.FindDeviceByName(name);
|
Device* d = devices_->FindDeviceByName(name);
|
||||||
if (d == nullptr) {
|
if (d == nullptr) {
|
||||||
return PartitionOptions::kIllegalIncarnation;
|
return PartitionOptions::kIllegalIncarnation;
|
||||||
} else {
|
} else {
|
||||||
@ -1214,7 +1301,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
rcg->RegisterPartitions(popts, rcg->client_graph()->flib_def->ToProto()));
|
rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -1222,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
|||||||
Status MasterSession::DoPartialRun(CallOptions* opts,
|
Status MasterSession::DoPartialRun(CallOptions* opts,
|
||||||
const RunStepRequestWrapper& req,
|
const RunStepRequestWrapper& req,
|
||||||
MutableRunStepResponseWrapper* resp) {
|
MutableRunStepResponseWrapper* resp) {
|
||||||
|
auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
|
||||||
const string& prun_handle = req.partial_run_handle();
|
const string& prun_handle = req.partial_run_handle();
|
||||||
RunState* run_state = nullptr;
|
RunState* run_state = nullptr;
|
||||||
{
|
{
|
||||||
@ -1320,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
|
|||||||
rcg->Ref();
|
rcg->Ref();
|
||||||
rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
|
rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
|
||||||
req.options(), resp->mutable_metadata());
|
req.options(), resp->mutable_metadata());
|
||||||
|
cleanup.release(); // MarkRunCompletion called in done closure.
|
||||||
rcg->CleanupPartitionsAsync(
|
rcg->CleanupPartitionsAsync(
|
||||||
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
|
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
LOG(ERROR) << "Cleanup partition error: " << s;
|
LOG(ERROR) << "Cleanup partition error: " << s;
|
||||||
}
|
}
|
||||||
rcg->Unref();
|
rcg->Unref();
|
||||||
|
MarkRunCompletion();
|
||||||
});
|
});
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
partial_runs_.erase(prun_handle);
|
partial_runs_.erase(prun_handle);
|
||||||
@ -1367,10 +1457,10 @@ Status MasterSession::CreateDebuggerState(
|
|||||||
Status MasterSession::DoRunWithLocalExecution(
|
Status MasterSession::DoRunWithLocalExecution(
|
||||||
CallOptions* opts, const RunStepRequestWrapper& req,
|
CallOptions* opts, const RunStepRequestWrapper& req,
|
||||||
MutableRunStepResponseWrapper* resp) {
|
MutableRunStepResponseWrapper* resp) {
|
||||||
VLOG(2) << "DoRunWithLocalExecution "
|
VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
|
||||||
<< "req: " << req.DebugString();
|
|
||||||
PerStepState pss;
|
PerStepState pss;
|
||||||
pss.start_micros = Env::Default()->NowMicros();
|
pss.start_micros = Env::Default()->NowMicros();
|
||||||
|
auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
|
||||||
|
|
||||||
// Prepare.
|
// Prepare.
|
||||||
BuildGraphOptions bgopts;
|
BuildGraphOptions bgopts;
|
||||||
@ -1437,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
rcg->Ref();
|
rcg->Ref();
|
||||||
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
|
cleanup.release(); // MarkRunCompletion called in done closure.
|
||||||
|
rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
LOG(ERROR) << "Cleanup partition error: " << s;
|
LOG(ERROR) << "Cleanup partition error: " << s;
|
||||||
}
|
}
|
||||||
rcg->Unref();
|
rcg->Unref();
|
||||||
|
MarkRunCompletion();
|
||||||
});
|
});
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||||
#include "tensorflow/core/distributed_runtime/master_env.h"
|
#include "tensorflow/core/distributed_runtime/master_env.h"
|
||||||
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
|
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/protobuf/master.pb.h"
|
#include "tensorflow/core/protobuf/master.pb.h"
|
||||||
@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted {
|
|||||||
MasterSession(
|
MasterSession(
|
||||||
const SessionOptions& options, const MasterEnv* env,
|
const SessionOptions& options, const MasterEnv* env,
|
||||||
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
|
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
|
||||||
|
std::unique_ptr<WorkerCacheInterface> worker_cache,
|
||||||
|
std::unique_ptr<DeviceSet> device_set,
|
||||||
StatsPublisherFactory stats_publisher_factory);
|
StatsPublisherFactory stats_publisher_factory);
|
||||||
|
|
||||||
// Initialize the MasterSession for "def". Must be called before Extend(),
|
// Initialize the MasterSession for "def". Must be called before Extend(),
|
||||||
// Run(), or Close().
|
// Run(), or Close().
|
||||||
//
|
//
|
||||||
// After this method returns, `def` will no longer be valid.
|
// After this method returns, `def` will no longer be valid.
|
||||||
Status Create(GraphDef* def);
|
Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);
|
||||||
|
|
||||||
// Returns the session handle.
|
// Returns the session handle.
|
||||||
const string& handle() const { return handle_; }
|
const string& handle() const { return handle_; }
|
||||||
@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted {
|
|||||||
|
|
||||||
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
|
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
|
||||||
|
|
||||||
|
// The optional session-specific worker cluster.
|
||||||
|
// TODO(saeta): Convert to std::optional when available.
|
||||||
|
std::unique_ptr<WorkerCacheInterface> worker_cache_;
|
||||||
|
// Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
|
||||||
|
WorkerCacheInterface* get_worker_cache() const;
|
||||||
|
|
||||||
// The device set used by this session.
|
// The device set used by this session.
|
||||||
DeviceSet devices_;
|
std::unique_ptr<DeviceSet> devices_;
|
||||||
|
|
||||||
StatsPublisherFactory stats_publisher_factory_;
|
StatsPublisherFactory stats_publisher_factory_;
|
||||||
|
|
||||||
@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted {
|
|||||||
// Private dtor. The client must call Close().
|
// Private dtor. The client must call Close().
|
||||||
virtual ~MasterSession();
|
virtual ~MasterSession();
|
||||||
|
|
||||||
|
// Creates sessions on all workers.
|
||||||
|
//
|
||||||
|
// If this session is operating using the new ClusterSpec propagation behavior
|
||||||
|
// call this method in order to propagate the cluster membership to all
|
||||||
|
// workers.
|
||||||
|
Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
|
||||||
|
|
||||||
Status StartStep(const BuildGraphOptions& opts, int64* count,
|
Status StartStep(const BuildGraphOptions& opts, int64* count,
|
||||||
ReffedClientGraph** graph, bool is_partial);
|
ReffedClientGraph** graph, bool is_partial);
|
||||||
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
||||||
@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted {
|
|||||||
MutableRunStepResponseWrapper* resp);
|
MutableRunStepResponseWrapper* resp);
|
||||||
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
|
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
|
||||||
MutableRunStepResponseWrapper* resp);
|
MutableRunStepResponseWrapper* resp);
|
||||||
|
void MarkRunCompletion();
|
||||||
void UpdateLastAccessTime();
|
void UpdateLastAccessTime();
|
||||||
|
|
||||||
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
|
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
|
||||||
|
@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const {
|
|||||||
|
|
||||||
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
|
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
|
||||||
|
|
||||||
|
const string& InMemoryRunGraphRequest::session_handle() const {
|
||||||
|
return session_handle_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
|
||||||
|
session_handle_ = handle;
|
||||||
|
}
|
||||||
|
|
||||||
const string& InMemoryRunGraphRequest::graph_handle() const {
|
const string& InMemoryRunGraphRequest::graph_handle() const {
|
||||||
return graph_handle_;
|
return graph_handle_;
|
||||||
}
|
}
|
||||||
@ -320,6 +328,7 @@ void InMemoryRunGraphRequest::set_is_last_partial_run(
|
|||||||
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
|
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
|
||||||
if (!proto_version_) {
|
if (!proto_version_) {
|
||||||
proto_version_.reset(new RunGraphRequest);
|
proto_version_.reset(new RunGraphRequest);
|
||||||
|
proto_version_->set_session_handle(session_handle());
|
||||||
proto_version_->set_graph_handle(graph_handle());
|
proto_version_->set_graph_handle(graph_handle());
|
||||||
proto_version_->set_step_id(step_id());
|
proto_version_->set_step_id(step_id());
|
||||||
*proto_version_->mutable_exec_opts() = exec_opts();
|
*proto_version_->mutable_exec_opts() = exec_opts();
|
||||||
@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
|
|||||||
return *proto_version_;
|
return *proto_version_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const string& MutableProtoRunGraphRequest::session_handle() const {
|
||||||
|
return request_.session_handle();
|
||||||
|
}
|
||||||
|
|
||||||
|
void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
|
||||||
|
request_.set_session_handle(handle);
|
||||||
|
}
|
||||||
|
|
||||||
const string& MutableProtoRunGraphRequest::graph_handle() const {
|
const string& MutableProtoRunGraphRequest::graph_handle() const {
|
||||||
return request_.graph_handle();
|
return request_.graph_handle();
|
||||||
}
|
}
|
||||||
@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
|
|||||||
ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
|
ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
|
||||||
: request_(request) {}
|
: request_(request) {}
|
||||||
|
|
||||||
|
const string& ProtoRunGraphRequest::session_handle() const {
|
||||||
|
return request_->session_handle();
|
||||||
|
}
|
||||||
|
|
||||||
const string& ProtoRunGraphRequest::graph_handle() const {
|
const string& ProtoRunGraphRequest::graph_handle() const {
|
||||||
return request_->graph_handle();
|
return request_->graph_handle();
|
||||||
}
|
}
|
||||||
|
@ -223,6 +223,10 @@ class RunGraphRequestWrapper {
|
|||||||
public:
|
public:
|
||||||
virtual ~RunGraphRequestWrapper() {}
|
virtual ~RunGraphRequestWrapper() {}
|
||||||
|
|
||||||
|
// The session handle used to register the graph. If empty, a single global
|
||||||
|
// namespace is used.
|
||||||
|
virtual const string& session_handle() const = 0;
|
||||||
|
|
||||||
// REQUIRED: graph_handle must be returned by a RegisterGraph call
|
// REQUIRED: graph_handle must be returned by a RegisterGraph call
|
||||||
// to the same WorkerService.
|
// to the same WorkerService.
|
||||||
virtual const string& graph_handle() const = 0;
|
virtual const string& graph_handle() const = 0;
|
||||||
@ -262,6 +266,7 @@ class RunGraphRequestWrapper {
|
|||||||
// See `RunGraphRequestWrapper` above for a description of the fields.
|
// See `RunGraphRequestWrapper` above for a description of the fields.
|
||||||
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
|
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
|
||||||
public:
|
public:
|
||||||
|
virtual void set_session_handle(const string& handle) = 0;
|
||||||
virtual void set_graph_handle(const string& handle) = 0;
|
virtual void set_graph_handle(const string& handle) = 0;
|
||||||
virtual void set_step_id(int64 step_id) = 0;
|
virtual void set_step_id(int64 step_id) = 0;
|
||||||
virtual ExecutorOpts* mutable_exec_opts() = 0;
|
virtual ExecutorOpts* mutable_exec_opts() = 0;
|
||||||
@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
|
|||||||
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
|
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
|
||||||
public:
|
public:
|
||||||
// RunGraphRequestWrapper methods.
|
// RunGraphRequestWrapper methods.
|
||||||
|
const string& session_handle() const override;
|
||||||
const string& graph_handle() const override;
|
const string& graph_handle() const override;
|
||||||
int64 step_id() const override;
|
int64 step_id() const override;
|
||||||
const ExecutorOpts& exec_opts() const override;
|
const ExecutorOpts& exec_opts() const override;
|
||||||
@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
|
|||||||
const RunGraphRequest& ToProto() const override;
|
const RunGraphRequest& ToProto() const override;
|
||||||
|
|
||||||
// MutableRunGraphRequestWrapper methods.
|
// MutableRunGraphRequestWrapper methods.
|
||||||
|
void set_session_handle(const string& handle) override;
|
||||||
void set_graph_handle(const string& handle) override;
|
void set_graph_handle(const string& handle) override;
|
||||||
void set_step_id(int64 step_id) override;
|
void set_step_id(int64 step_id) override;
|
||||||
ExecutorOpts* mutable_exec_opts() override;
|
ExecutorOpts* mutable_exec_opts() override;
|
||||||
@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
|
|||||||
void set_is_last_partial_run(bool is_last_partial_run) override;
|
void set_is_last_partial_run(bool is_last_partial_run) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
string session_handle_;
|
||||||
string graph_handle_;
|
string graph_handle_;
|
||||||
int64 step_id_;
|
int64 step_id_;
|
||||||
ExecutorOpts exec_opts_;
|
ExecutorOpts exec_opts_;
|
||||||
@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
|
|||||||
class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
|
class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
|
||||||
public:
|
public:
|
||||||
// RunGraphRequestWrapper methods.
|
// RunGraphRequestWrapper methods.
|
||||||
|
const string& session_handle() const override;
|
||||||
const string& graph_handle() const override;
|
const string& graph_handle() const override;
|
||||||
int64 step_id() const override;
|
int64 step_id() const override;
|
||||||
const ExecutorOpts& exec_opts() const override;
|
const ExecutorOpts& exec_opts() const override;
|
||||||
@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
|
|||||||
const RunGraphRequest& ToProto() const override;
|
const RunGraphRequest& ToProto() const override;
|
||||||
|
|
||||||
// MutableRunGraphRequestWrapper methods.
|
// MutableRunGraphRequestWrapper methods.
|
||||||
|
void set_session_handle(const string& handle) override;
|
||||||
void set_graph_handle(const string& handle) override;
|
void set_graph_handle(const string& handle) override;
|
||||||
void set_step_id(int64 step_id) override;
|
void set_step_id(int64 step_id) override;
|
||||||
ExecutorOpts* mutable_exec_opts() override;
|
ExecutorOpts* mutable_exec_opts() override;
|
||||||
@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
|
|||||||
ProtoRunGraphRequest(const RunGraphRequest* request);
|
ProtoRunGraphRequest(const RunGraphRequest* request);
|
||||||
|
|
||||||
// RunGraphRequestWrapper methods.
|
// RunGraphRequestWrapper methods.
|
||||||
|
const string& session_handle() const override;
|
||||||
const string& graph_handle() const override;
|
const string& graph_handle() const override;
|
||||||
int64 step_id() const override;
|
int64 step_id() const override;
|
||||||
const ExecutorOpts& exec_opts() const override;
|
const ExecutorOpts& exec_opts() const override;
|
||||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/process_util.h"
|
#include "tensorflow/core/common_runtime/process_util.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||||
@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) {
|
|||||||
class RemoteDevice : public Device {
|
class RemoteDevice : public Device {
|
||||||
public:
|
public:
|
||||||
RemoteDevice(Env* env, const DeviceAttributes& da)
|
RemoteDevice(Env* env, const DeviceAttributes& da)
|
||||||
: Device(env, da, nullptr),
|
: Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {}
|
||||||
local_dev_name_(GetLocalDeviceName(da.name())) {}
|
|
||||||
|
|
||||||
Status Sync() override { return Status::OK(); }
|
Status Sync() override { return Status::OK(); }
|
||||||
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
|
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
|
||||||
@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
|
|||||||
GetStatusResponse resp;
|
GetStatusResponse resp;
|
||||||
};
|
};
|
||||||
Call* call = new Call;
|
Call* call = new Call;
|
||||||
auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) {
|
auto cb = [env, worker_cache, worker_name, done, wi,
|
||||||
|
call](const Status& status) {
|
||||||
|
Status s = status;
|
||||||
std::vector<Device*> remote_devices;
|
std::vector<Device*> remote_devices;
|
||||||
if (s.ok()) {
|
auto cleanup = gtl::MakeCleanup(
|
||||||
remote_devices.reserve(call->resp.device_attributes_size());
|
[&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
|
||||||
for (const DeviceAttributes& da : call->resp.device_attributes()) {
|
|
||||||
auto d = new RemoteDevice(env, da);
|
|
||||||
remote_devices.push_back(d);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
worker_cache->ReleaseWorker(worker_name, wi);
|
worker_cache->ReleaseWorker(worker_name, wi);
|
||||||
done(s, &remote_devices);
|
done(s, &remote_devices);
|
||||||
delete call;
|
delete call;
|
||||||
|
});
|
||||||
|
if (s.ok()) {
|
||||||
|
DeviceNameUtils::ParsedName worker_name_parsed;
|
||||||
|
if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
|
||||||
|
!worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
|
||||||
|
!worker_name_parsed.has_task) {
|
||||||
|
s = errors::InvalidArgument("Could not parse worker name: ",
|
||||||
|
worker_name);
|
||||||
|
LOG(WARNING) << s;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
remote_devices.reserve(call->resp.device_attributes_size());
|
||||||
|
for (const DeviceAttributes& da : call->resp.device_attributes()) {
|
||||||
|
DeviceNameUtils::ParsedName device_name_parsed;
|
||||||
|
CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
|
||||||
|
<< "Device attribute name '" << da.name() << "' could not be "
|
||||||
|
<< "parsed. Device Attribute: " << da.DebugString();
|
||||||
|
// Preserve the exact name, if possible.
|
||||||
|
// TODO(b/37868888): Simplify when legacy device name formats removed.
|
||||||
|
if (device_name_parsed.job == worker_name_parsed.job &&
|
||||||
|
device_name_parsed.replica == worker_name_parsed.replica &&
|
||||||
|
device_name_parsed.task == worker_name_parsed.task) {
|
||||||
|
auto d = new RemoteDevice(env, da);
|
||||||
|
remote_devices.push_back(d);
|
||||||
|
} else {
|
||||||
|
DeviceAttributes da_rewritten = da;
|
||||||
|
da_rewritten.set_name(DeviceNameUtils::FullName(
|
||||||
|
worker_name_parsed.job, worker_name_parsed.replica,
|
||||||
|
worker_name_parsed.task, device_name_parsed.type,
|
||||||
|
device_name_parsed.id));
|
||||||
|
auto d = new RemoteDevice(env, da_rewritten);
|
||||||
|
remote_devices.push_back(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
wi->GetStatusAsync(&call->req, &call->resp, cb);
|
wi->GetStatusAsync(&call->req, &call->resp, cb);
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,23 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
struct WorkerSession;
|
||||||
|
|
||||||
|
// RemoteRendezvous follow a 2-part initialization. First the objects are
|
||||||
|
// constructed. Eventually, they will be initialized. Clients of the
|
||||||
|
// RendezvousMgrInterface must guarantee to call Initialize on the returned
|
||||||
|
// RemoteRendezvous eventually.
|
||||||
|
//
|
||||||
|
// Partially initialized RemoteRendezvous must respect the Rendezvous interface
|
||||||
|
// (i.e. Send() must never block), however implementations are not expected to
|
||||||
|
// actually perform the underlying operations until after the RemoteRendezvous
|
||||||
|
// has been Initialize'd.
|
||||||
|
class RemoteRendezvous : public Rendezvous {
|
||||||
|
public:
|
||||||
|
// Fully construct the RemoteRendezvous.
|
||||||
|
virtual Status Initialize(WorkerSession* session) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
||||||
// All tensors sent by this worker are buffered in a RendezvousMgr
|
// All tensors sent by this worker are buffered in a RendezvousMgr
|
||||||
// until the tensor is received. Each global unique "step_id"
|
// until the tensor is received. Each global unique "step_id"
|
||||||
@ -51,7 +68,10 @@ class RendezvousMgrInterface {
|
|||||||
// Returns Rendezvous supporting send and recv among workers in the
|
// Returns Rendezvous supporting send and recv among workers in the
|
||||||
// "step_id". The caller takes ownership of one reference on the
|
// "step_id". The caller takes ownership of one reference on the
|
||||||
// returned Rendezvous instance.
|
// returned Rendezvous instance.
|
||||||
virtual Rendezvous* Find(int64 step_id) = 0;
|
//
|
||||||
|
// Note: the caller must guarantee to eventually call Initialize on the
|
||||||
|
// returned RemoteRendezvous
|
||||||
|
virtual RemoteRendezvous* Find(int64 step_id) = 0;
|
||||||
|
|
||||||
// Finds the local rendezvous instance for the "step_id". Runs
|
// Finds the local rendezvous instance for the "step_id". Runs
|
||||||
// "done" when the tensor for "key" is produced or an error occurs.
|
// "done" when the tensor for "key" is produced or an error occurs.
|
||||||
|
@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// static utility function
|
// static utility function
|
||||||
RendezvousMgrInterface* NewRpcRendezvousMgr(
|
RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
|
||||||
const WorkerEnv* env, const string& worker_name,
|
return new RpcRendezvousMgr(env);
|
||||||
WorkerCacheInterface* worker_cache) {
|
|
||||||
return new RpcRendezvousMgr(env, worker_name, worker_cache);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() {
|
|||||||
// TODO(mrry): Refactor the *Env classes so that it is less fiddly
|
// TODO(mrry): Refactor the *Env classes so that it is less fiddly
|
||||||
// to destroy them.
|
// to destroy them.
|
||||||
|
|
||||||
|
// Shut down all outstanding rendezvous.
|
||||||
|
delete worker_env_.rendezvous_mgr;
|
||||||
|
|
||||||
// We must delete graph_mgr before device_mgr, due to shared
|
// We must delete graph_mgr before device_mgr, due to shared
|
||||||
// ownership of OpKernels in the executors. (The graph_mgr will
|
// ownership of OpKernels in the executors. (The graph_mgr will
|
||||||
// free all stateless OpKernels, and pass over borrowed stateful
|
// free all stateless OpKernels, and pass over borrowed stateful
|
||||||
@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() {
|
|||||||
// OpSegments.)
|
// OpSegments.)
|
||||||
if (worker_env_.session_mgr != nullptr) {
|
if (worker_env_.session_mgr != nullptr) {
|
||||||
delete worker_env_.session_mgr; // Deletes graph_mgr's.
|
delete worker_env_.session_mgr; // Deletes graph_mgr's.
|
||||||
}
|
} else {
|
||||||
|
// Note: session_mgr's legacy_session_ deletes device_mgr now.
|
||||||
delete worker_env_.device_mgr;
|
delete worker_env_.device_mgr;
|
||||||
|
}
|
||||||
|
|
||||||
// Do not delete (as these are not owned by the server):
|
// Do not delete (as these are not owned by the server):
|
||||||
// - master_env_.env
|
// - master_env_.env
|
||||||
@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() {
|
|||||||
// - worker_env_.compute_pool
|
// - worker_env_.compute_pool
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GrpcServer::Init(ServiceInitFunction service_func,
|
Status GrpcServer::Init(
|
||||||
RendezvousMgrCreationFunction rendevous_mgr_func) {
|
ServiceInitFunction service_func,
|
||||||
|
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
CHECK_EQ(state_, NEW);
|
CHECK_EQ(state_, NEW);
|
||||||
master_env_.env = env_;
|
master_env_.env = env_;
|
||||||
@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
|
|||||||
"/task:", server_def_.task_index());
|
"/task:", server_def_.task_index());
|
||||||
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
|
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
|
||||||
&master_env_.local_devices));
|
&master_env_.local_devices));
|
||||||
worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices);
|
worker_env_.local_devices = master_env_.local_devices;
|
||||||
|
worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
|
||||||
|
worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
|
||||||
|
? new RpcRendezvousMgr(&worker_env_)
|
||||||
|
: rendezvous_mgr_func(&worker_env_);
|
||||||
string unused;
|
string unused;
|
||||||
string default_worker_name;
|
string default_worker_name;
|
||||||
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
|
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
|
||||||
@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
|
|||||||
}
|
}
|
||||||
|
|
||||||
WorkerCacheInterface* worker_cache;
|
WorkerCacheInterface* worker_cache;
|
||||||
TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache));
|
WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
|
||||||
CHECK_NE(nullptr, worker_cache);
|
CHECK_NE(nullptr, worker_cache);
|
||||||
|
|
||||||
// Set up worker environment.
|
// Set up worker environment.
|
||||||
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
|
|
||||||
rendevous_mgr_func == nullptr ?
|
|
||||||
new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
|
|
||||||
rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
|
|
||||||
worker_env_.session_mgr = new SessionMgr(
|
worker_env_.session_mgr = new SessionMgr(
|
||||||
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
|
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
|
||||||
std::unique_ptr<WorkerCacheInterface>(worker_cache),
|
std::unique_ptr<WorkerCacheInterface>(worker_cache),
|
||||||
std::move(rendezvous_mgr),
|
|
||||||
[this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
|
[this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
|
||||||
return WorkerCacheFactory(server_def, worker_cache);
|
WorkerCacheFactoryOptions options(server_def);
|
||||||
|
return WorkerCacheFactory(options, worker_cache);
|
||||||
});
|
});
|
||||||
worker_env_.compute_pool = ComputePool(sess_opts);
|
worker_env_.compute_pool = ComputePool(sess_opts);
|
||||||
|
|
||||||
@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
|
|||||||
master_env_.master_session_factory =
|
master_env_.master_session_factory =
|
||||||
[config](
|
[config](
|
||||||
SessionOptions options, const MasterEnv* env,
|
SessionOptions options, const MasterEnv* env,
|
||||||
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs) {
|
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
|
||||||
|
std::unique_ptr<WorkerCacheInterface> worker_cache,
|
||||||
|
std::unique_ptr<DeviceSet> device_set) {
|
||||||
options.config.MergeFrom(config);
|
options.config.MergeFrom(config);
|
||||||
return new MasterSession(options, env, std::move(remote_devs),
|
return new MasterSession(options, env, std::move(remote_devs),
|
||||||
|
std::move(worker_cache), std::move(device_set),
|
||||||
CreateNoOpStatsPublisher);
|
CreateNoOpStatsPublisher);
|
||||||
};
|
};
|
||||||
|
master_env_.worker_cache_factory =
|
||||||
|
[this](const WorkerCacheFactoryOptions& options,
|
||||||
|
WorkerCacheInterface** worker_cache) {
|
||||||
|
return WorkerCacheFactory(options, worker_cache);
|
||||||
|
};
|
||||||
|
|
||||||
// Provide direct access to the master from in-process clients.
|
// Provide direct access to the master from in-process clients.
|
||||||
LocalMaster::Register(target(), master_impl_.get(),
|
LocalMaster::Register(target(), master_impl_.get(),
|
||||||
@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GrpcServer::Init() {
|
Status GrpcServer::Init() { return Init(nullptr, nullptr); }
|
||||||
return Init(nullptr, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
|
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
|
||||||
GrpcChannelSpec* channel_spec) {
|
GrpcChannelSpec* channel_spec) {
|
||||||
for (const auto& job : server_def.cluster().job()) {
|
for (const auto& job : options.cluster_def->job()) {
|
||||||
std::map<int, string> host_ports;
|
std::map<int, string> host_ports;
|
||||||
for (const auto& task : job.tasks()) {
|
for (const auto& task : job.tasks()) {
|
||||||
string& host_port = host_ports[task.first];
|
string& host_port = host_ports[task.first];
|
||||||
@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
|
|||||||
task.first, "\": ", host_port, " and ",
|
task.first, "\": ", host_port, " and ",
|
||||||
task.second);
|
task.second);
|
||||||
}
|
}
|
||||||
if (job.name() == server_def.job_name() &&
|
if (job.name() == *options.job_name && task.first == options.task_index) {
|
||||||
task.first == server_def.task_index()) {
|
|
||||||
host_port = strings::StrCat("localhost:", bound_port_);
|
host_port = strings::StrCat("localhost:", bound_port_);
|
||||||
} else {
|
} else {
|
||||||
host_port = task.second;
|
host_port = task.second;
|
||||||
@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def,
|
Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
||||||
WorkerCacheInterface** worker_cache) {
|
WorkerCacheInterface** worker_cache) {
|
||||||
string name_prefix =
|
if (options.job_name == nullptr || options.job_name->empty()) {
|
||||||
strings::StrCat("/job:", server_def.job_name(), "/replica:0",
|
Status s = errors::InvalidArgument(
|
||||||
"/task:", server_def.task_index());
|
"The master (current machine) is not included in the provided "
|
||||||
|
"cluster_def. ",
|
||||||
|
options.cluster_def->DebugString());
|
||||||
|
LOG(WARNING) << s;
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
GrpcChannelSpec channel_spec;
|
GrpcChannelSpec channel_spec;
|
||||||
TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
|
TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
|
||||||
|
|
||||||
|
std::unique_ptr<GrpcChannelCache> channel_cache(
|
||||||
|
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
|
||||||
|
|
||||||
|
string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
|
||||||
|
"/task:", options.task_index);
|
||||||
|
|
||||||
std::unique_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
|
|
||||||
channel_spec, GetChannelCreationFunction(server_def)));
|
|
||||||
const string host_port = channel_cache->TranslateTask(name_prefix);
|
const string host_port = channel_cache->TranslateTask(name_prefix);
|
||||||
int requested_port;
|
int requested_port;
|
||||||
|
|
||||||
@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
|
|||||||
return ::grpc::InsecureServerCredentials();
|
return ::grpc::InsecureServerCredentials();
|
||||||
}
|
}
|
||||||
|
|
||||||
ChannelCreationFunction GrpcServer::GetChannelCreationFunction(
|
ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
|
||||||
const ServerDef& server_def) const {
|
|
||||||
// We can do this because SparseGrpcChannelCache is robust to nullptr being
|
// We can do this because SparseGrpcChannelCache is robust to nullptr being
|
||||||
// returned by the channel creation function
|
// returned by the channel creation function
|
||||||
return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
||||||
|
@ -37,9 +37,7 @@ class GrpcWorker;
|
|||||||
class Master;
|
class Master;
|
||||||
|
|
||||||
// function that creates a RendezvousMgr.
|
// function that creates a RendezvousMgr.
|
||||||
typedef std::function<RendezvousMgrInterface*(
|
typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
|
||||||
const WorkerEnv*, const std::string& worker_name,
|
|
||||||
WorkerCacheInterface* worker_cache)>
|
|
||||||
RendezvousMgrCreationFunction;
|
RendezvousMgrCreationFunction;
|
||||||
|
|
||||||
// function that registers a service to the server. The service needs to
|
// function that registers a service to the server. The service needs to
|
||||||
@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status Init(ServiceInitFunction service_func,
|
Status Init(ServiceInitFunction service_func,
|
||||||
RendezvousMgrCreationFunction rendezvous_mgr_func);
|
const RendezvousMgrCreationFunction& rendezvous_mgr_func);
|
||||||
|
|
||||||
Status Init();
|
Status Init();
|
||||||
|
|
||||||
@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface {
|
|||||||
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
|
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
|
||||||
const ServerDef& server_def) const;
|
const ServerDef& server_def) const;
|
||||||
|
|
||||||
virtual ChannelCreationFunction GetChannelCreationFunction(
|
virtual ChannelCreationFunction GetChannelCreationFunction() const;
|
||||||
const ServerDef& server_def) const;
|
|
||||||
|
|
||||||
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
|
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
|
||||||
|
|
||||||
// Creates a WorkerCacheInterface for a session.
|
// Creates a WorkerCacheInterface for a session.
|
||||||
Status WorkerCacheFactory(const ServerDef& server_def,
|
Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
||||||
WorkerCacheInterface** worker_cache);
|
WorkerCacheInterface** worker_cache);
|
||||||
|
|
||||||
// Parses a ServerDef into a GrpcChannelSpec.
|
// Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
|
||||||
Status ParseChannelSpec(const ServerDef& server_def,
|
Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
|
||||||
GrpcChannelSpec* channel_spec);
|
GrpcChannelSpec* channel_spec);
|
||||||
|
|
||||||
// Returns the port to which this server is bound.
|
// Returns the port to which this server is bound.
|
||||||
|
@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix);
|
|||||||
/* static */
|
/* static */
|
||||||
Status GrpcSession::Create(const SessionOptions& options,
|
Status GrpcSession::Create(const SessionOptions& options,
|
||||||
std::unique_ptr<GrpcSession>* out_session) {
|
std::unique_ptr<GrpcSession>* out_session) {
|
||||||
std::unique_ptr<GrpcSession> ret(new GrpcSession(options));
|
std::unique_ptr<GrpcSession> session(new GrpcSession(options));
|
||||||
std::unique_ptr<MasterInterface> master;
|
std::unique_ptr<MasterInterface> master;
|
||||||
// For testing, we enable the client to disable the use of the local
|
// For testing, we enable the client to disable the use of the local
|
||||||
// master registry, so that the RPC stack is exercised.
|
// master registry, so that the RPC stack is exercised.
|
||||||
@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options,
|
|||||||
options.target.substr(kSchemePrefixLength), &master_channel));
|
options.target.substr(kSchemePrefixLength), &master_channel));
|
||||||
master.reset(NewGrpcMaster(master_channel));
|
master.reset(NewGrpcMaster(master_channel));
|
||||||
}
|
}
|
||||||
ret->SetRemoteMaster(std::move(master));
|
session->SetRemoteMaster(std::move(master));
|
||||||
*out_session = std::move(ret);
|
*out_session = std::move(session);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options,
|
|||||||
CreateSessionRequest req;
|
CreateSessionRequest req;
|
||||||
*req.mutable_config() = options_.config;
|
*req.mutable_config() = options_.config;
|
||||||
*req.mutable_graph_def() = graph;
|
*req.mutable_graph_def() = graph;
|
||||||
|
req.set_target(options_.target);
|
||||||
ReEncodeConsts(req.mutable_graph_def());
|
ReEncodeConsts(req.mutable_graph_def());
|
||||||
CreateSessionResponse resp;
|
CreateSessionResponse resp;
|
||||||
Status s = master_->CreateSession(call_options, &req, &resp);
|
Status s = master_->CreateSession(call_options, &req, &resp);
|
||||||
|
@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
|||||||
// completes, and we may decide to bound some of the request
|
// completes, and we may decide to bound some of the request
|
||||||
// types.
|
// types.
|
||||||
ENQUEUE_REQUEST(GetStatus, false);
|
ENQUEUE_REQUEST(GetStatus, false);
|
||||||
|
ENQUEUE_REQUEST(CreateWorkerSession, false);
|
||||||
ENQUEUE_REQUEST(CleanupAll, false);
|
ENQUEUE_REQUEST(CleanupAll, false);
|
||||||
ENQUEUE_REQUEST(RegisterGraph, false);
|
ENQUEUE_REQUEST(RegisterGraph, false);
|
||||||
ENQUEUE_REQUEST(DeregisterGraph, false);
|
ENQUEUE_REQUEST(DeregisterGraph, false);
|
||||||
@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
|||||||
ENQUEUE_REQUEST(GetStatus, false);
|
ENQUEUE_REQUEST(GetStatus, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CreateWorkerSessionHandler(
|
||||||
|
WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
|
||||||
|
call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->CreateWorkerSession(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(CreateWorkerSession, false);
|
||||||
|
}
|
||||||
|
|
||||||
void CleanupAllHandler(
|
void CleanupAllHandler(
|
||||||
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
|
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
|
||||||
Schedule([this, call]() {
|
Schedule([this, call]() {
|
||||||
@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
|
|||||||
::grpc::ByteBuffer* response,
|
::grpc::ByteBuffer* response,
|
||||||
StatusCallback done) {
|
StatusCallback done) {
|
||||||
const int64 step_id = request->step_id();
|
const int64 step_id = request->step_id();
|
||||||
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
|
|
||||||
const string& key = request->rendezvous_key();
|
const string& key = request->rendezvous_key();
|
||||||
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
|
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
|
||||||
Rendezvous::ParsedKey parsed;
|
Rendezvous::ParsedKey parsed;
|
||||||
@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
|
|||||||
// of execution of the callback lambda body below, an RPC
|
// of execution of the callback lambda body below, an RPC
|
||||||
// cancellation should abort the rendezvous.
|
// cancellation should abort the rendezvous.
|
||||||
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
|
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
|
||||||
session->rendezvous_mgr->RecvLocalAsync(
|
env_->rendezvous_mgr->RecvLocalAsync(
|
||||||
step_id, parsed,
|
step_id, parsed,
|
||||||
[opts, response, done, src_dev](const Status& status,
|
[opts, response, done, src_dev](const Status& status,
|
||||||
const Rendezvous::Args& send_args,
|
const Rendezvous::Args& send_args,
|
||||||
|
@ -38,9 +38,8 @@ namespace {
|
|||||||
|
|
||||||
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||||
public:
|
public:
|
||||||
RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
|
RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
|
||||||
WorkerCacheInterface* cache, int64 step_id)
|
: BaseRemoteRendezvous(env, step_id, false) {}
|
||||||
: BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||||
@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
|||||||
private:
|
private:
|
||||||
~RpcRemoteRendezvous() override {}
|
~RpcRemoteRendezvous() override {}
|
||||||
|
|
||||||
WorkerCacheInterface* const cache_; // Not owned.
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
|
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() {
|
|||||||
return call_freelist;
|
return call_freelist;
|
||||||
}
|
}
|
||||||
|
|
||||||
// A private cache that wraps worker_cache and allows reuse of
|
|
||||||
// WorkerInterface objects.
|
|
||||||
class WorkerFreeListCache : public WorkerCacheInterface {
|
|
||||||
public:
|
|
||||||
explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {}
|
|
||||||
|
|
||||||
~WorkerFreeListCache() {
|
|
||||||
for (auto p : workers_) {
|
|
||||||
wrapped_->ReleaseWorker(p.first, p.second.worker);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ListWorkers(std::vector<string>* workers) const override {
|
|
||||||
wrapped_->ListWorkers(workers);
|
|
||||||
}
|
|
||||||
|
|
||||||
WorkerInterface* CreateWorker(const string& target) override {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
auto p = workers_.find(target);
|
|
||||||
if (p != workers_.end()) {
|
|
||||||
return p->second.worker;
|
|
||||||
}
|
|
||||||
WorkerState state;
|
|
||||||
state.worker = wrapped_->CreateWorker(target);
|
|
||||||
if (state.worker != nullptr) {
|
|
||||||
workers_.insert(std::make_pair(target, state));
|
|
||||||
}
|
|
||||||
return state.worker;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ReleaseWorker(const string& target, WorkerInterface* worker) override {
|
|
||||||
// TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GetDeviceLocalityNonBlocking(const string& device,
|
|
||||||
DeviceLocality* locality) override {
|
|
||||||
return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
|
|
||||||
}
|
|
||||||
|
|
||||||
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
|
|
||||||
StatusCallback done) override {
|
|
||||||
wrapped_->GetDeviceLocalityAsync(device, locality, done);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetLogging(bool active) override { wrapped_->SetLogging(active); }
|
|
||||||
|
|
||||||
void ClearLogs() override { wrapped_->ClearLogs(); }
|
|
||||||
|
|
||||||
bool RetrieveLogs(int64 step_id, StepStats* ss) override {
|
|
||||||
return wrapped_->RetrieveLogs(step_id, ss);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
WorkerCacheInterface* wrapped_;
|
|
||||||
|
|
||||||
// Information kept per created WorkerInterface.
|
|
||||||
struct WorkerState {
|
|
||||||
WorkerInterface* worker;
|
|
||||||
// TODO(jeff,sanjay): Add reference count if we support eviction.
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO(jeff,sanjay): Eviction when the map becomes too big.
|
|
||||||
mutex mu_;
|
|
||||||
std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
|
|
||||||
};
|
|
||||||
|
|
||||||
void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
|
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
|
||||||
DoneCallback done) {
|
DoneCallback done) {
|
||||||
|
CHECK(is_initialized());
|
||||||
Status s;
|
Status s;
|
||||||
|
|
||||||
// Prepare a RecvTensor call that can handle being aborted.
|
// Prepare a RecvTensor call that can handle being aborted.
|
||||||
@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
|||||||
s = errors::Internal(parsed.src_device,
|
s = errors::Internal(parsed.src_device,
|
||||||
" is invalid remote source device.");
|
" is invalid remote source device.");
|
||||||
}
|
}
|
||||||
WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_);
|
WorkerSession* sess = session();
|
||||||
|
WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
|
||||||
if (s.ok() && rwi == nullptr) {
|
if (s.ok() && rwi == nullptr) {
|
||||||
s = errors::Internal("No worker known as ", call->src_worker_);
|
s = errors::Internal("No worker known as ", call->src_worker_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Device* dst_device;
|
Device* dst_device;
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||||
}
|
}
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
get_call_freelist()->Release(call, cache_);
|
if (rwi != nullptr) {
|
||||||
|
sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
|
||||||
|
}
|
||||||
|
get_call_freelist()->Release(call, sess->worker_cache.get());
|
||||||
done(s, Args(), recv_args, Tensor{}, false);
|
done(s, Args(), recv_args, Tensor{}, false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
|||||||
// current status should be bad.
|
// current status should be bad.
|
||||||
Status s = call->status();
|
Status s = call->status();
|
||||||
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
|
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
|
||||||
cache_->ReleaseWorker(call->src_worker_, call->wi_);
|
session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
|
||||||
call->wi_ = nullptr;
|
call->wi_ = nullptr;
|
||||||
get_call_freelist()->Release(call, cache_);
|
get_call_freelist()->Release(call, session()->worker_cache.get());
|
||||||
Unref();
|
Unref();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env,
|
RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
|
||||||
const string& worker_name,
|
: BaseRendezvousMgr(env) {}
|
||||||
WorkerCacheInterface* worker_cache)
|
|
||||||
: BaseRendezvousMgr(env, worker_name),
|
|
||||||
cache_(new WorkerFreeListCache(worker_cache)) {}
|
|
||||||
|
|
||||||
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
|
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
|
||||||
const WorkerEnv* worker_env,
|
const WorkerEnv* worker_env) {
|
||||||
const string& worker_name) {
|
return new RpcRemoteRendezvous(worker_env, step_id);
|
||||||
return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(),
|
|
||||||
step_id);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -17,13 +17,13 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
|
||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
|
||||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_session.h"
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class DeviceMgr;
|
||||||
|
|
||||||
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
||||||
// All tensors sent by this worker are buffered in a RendezvousMgr
|
// All tensors sent by this worker are buffered in a RendezvousMgr
|
||||||
// until the tensor is received. Each global unique "step_id"
|
// until the tensor is received. Each global unique "step_id"
|
||||||
@ -44,17 +44,12 @@ namespace tensorflow {
|
|||||||
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
|
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
|
||||||
class RpcRendezvousMgr : public BaseRendezvousMgr {
|
class RpcRendezvousMgr : public BaseRendezvousMgr {
|
||||||
public:
|
public:
|
||||||
explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name,
|
explicit RpcRendezvousMgr(const WorkerEnv* env);
|
||||||
WorkerCacheInterface* worker_cache);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
|
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);
|
||||||
const string& session_name) override;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Private cache_ that allows us to reuse WorkerInterface objects.
|
|
||||||
std::unique_ptr<WorkerCacheInterface> cache_;
|
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
|
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test {
|
|||||||
: cache_(new DummyWorkerCache),
|
: cache_(new DummyWorkerCache),
|
||||||
worker_session_("/job:mnist/replica:1/task:2",
|
worker_session_("/job:mnist/replica:1/task:2",
|
||||||
std::unique_ptr<WorkerCacheInterface>(cache_),
|
std::unique_ptr<WorkerCacheInterface>(cache_),
|
||||||
std::unique_ptr<RendezvousMgrInterface>(),
|
std::unique_ptr<DeviceMgr>(),
|
||||||
std::unique_ptr<GraphMgr>()),
|
std::unique_ptr<GraphMgr>()),
|
||||||
rmgr_(&env, worker_session_.worker_name, cache_) {
|
rmgr_(&env) {
|
||||||
env.env = Env::Default();
|
env.env = Env::Default();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) {
|
|||||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||||
{
|
{
|
||||||
Rendezvous* rendez = rmgr_.Find(step_id);
|
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||||
|
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||||
core::ScopedUnref unref(rendez);
|
core::ScopedUnref unref(rendez);
|
||||||
Rendezvous::Args args;
|
Rendezvous::Args args;
|
||||||
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
||||||
@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
|
|||||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||||
{ // Explicit Abort().
|
{ // Explicit Abort().
|
||||||
const int64 step_id = 123;
|
const int64 step_id = 123;
|
||||||
Rendezvous* rendez = rmgr_.Find(step_id);
|
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||||
core::ScopedUnref unref(rendez);
|
core::ScopedUnref unref(rendez);
|
||||||
SchedClosure([this, rendez]() {
|
SchedClosure([this, rendez]() {
|
||||||
env.env->SleepForMicroseconds(100 * 1000);
|
env.env->SleepForMicroseconds(100 * 1000);
|
||||||
@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
|
|||||||
Tensor val(DT_STRING);
|
Tensor val(DT_STRING);
|
||||||
bool val_dead = false;
|
bool val_dead = false;
|
||||||
Rendezvous::Args args;
|
Rendezvous::Args args;
|
||||||
|
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||||
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
|
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
|
||||||
}
|
}
|
||||||
{ // Cleanup causes Abort().
|
{ // Cleanup causes Abort().
|
||||||
const int64 step_id = 321;
|
const int64 step_id = 321;
|
||||||
Rendezvous* rendez = rmgr_.Find(step_id);
|
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||||
core::ScopedUnref unref(rendez);
|
core::ScopedUnref unref(rendez);
|
||||||
SchedClosure([this, step_id]() {
|
SchedClosure([this, step_id]() {
|
||||||
env.env->SleepForMicroseconds(100 * 1000);
|
env.env->SleepForMicroseconds(100 * 1000);
|
||||||
@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
|
|||||||
Tensor val(DT_STRING);
|
Tensor val(DT_STRING);
|
||||||
bool val_dead = false;
|
bool val_dead = false;
|
||||||
Rendezvous::Args args;
|
Rendezvous::Args args;
|
||||||
|
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||||
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
|
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) {
|
|||||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||||
{
|
{
|
||||||
const int64 step_id = 123;
|
const int64 step_id = 123;
|
||||||
Rendezvous* rendez = rmgr_.Find(step_id);
|
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||||
|
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||||
core::ScopedUnref unref(rendez);
|
core::ScopedUnref unref(rendez);
|
||||||
Rendezvous::Args args;
|
Rendezvous::Args args;
|
||||||
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
||||||
@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
|||||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||||
{
|
{
|
||||||
Rendezvous* rendez = rmgr_.Find(step_id);
|
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||||
core::ScopedUnref unref(rendez);
|
core::ScopedUnref unref(rendez);
|
||||||
Rendezvous::Args args;
|
Rendezvous::Args args;
|
||||||
args.device_context = dc;
|
args.device_context = dc;
|
||||||
|
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||||
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
|
@ -17,8 +17,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/renamed_device.h"
|
||||||
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
|
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -26,23 +27,12 @@ namespace tensorflow {
|
|||||||
SessionMgr::SessionMgr(
|
SessionMgr::SessionMgr(
|
||||||
WorkerEnv* worker_env, const string& default_worker_name,
|
WorkerEnv* worker_env, const string& default_worker_name,
|
||||||
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
|
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
|
||||||
std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
|
|
||||||
WorkerCacheFactory worker_cache_factory)
|
|
||||||
: SessionMgr(
|
|
||||||
worker_env, default_worker_name, std::move(default_worker_cache),
|
|
||||||
default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {}
|
|
||||||
|
|
||||||
SessionMgr::SessionMgr(
|
|
||||||
WorkerEnv* worker_env, const string& default_worker_name,
|
|
||||||
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
|
|
||||||
RendezvousMgrInterface* default_rendezvous_mgr,
|
|
||||||
WorkerCacheFactory worker_cache_factory)
|
WorkerCacheFactory worker_cache_factory)
|
||||||
: worker_env_(worker_env),
|
: worker_env_(worker_env),
|
||||||
legacy_session_(
|
legacy_session_(default_worker_name, std::move(default_worker_cache),
|
||||||
default_worker_name, std::move(default_worker_cache),
|
std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
|
||||||
std::unique_ptr<RendezvousMgrInterface>(default_rendezvous_mgr),
|
|
||||||
std::unique_ptr<GraphMgr>(
|
std::unique_ptr<GraphMgr>(
|
||||||
new GraphMgr(worker_env, default_rendezvous_mgr))),
|
new GraphMgr(worker_env, worker_env->device_mgr))),
|
||||||
worker_cache_factory_(std::move(worker_cache_factory)) {}
|
worker_cache_factory_(std::move(worker_cache_factory)) {}
|
||||||
|
|
||||||
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
|
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
|
||||||
@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
|
|||||||
Status SessionMgr::CreateSession(const string& session,
|
Status SessionMgr::CreateSession(const string& session,
|
||||||
const ServerDef& server_def) {
|
const ServerDef& server_def) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
if (session.empty()) {
|
||||||
|
return errors::InvalidArgument("Session must be non-empty.");
|
||||||
|
}
|
||||||
|
|
||||||
const string worker_name = WorkerNameFromServerDef(server_def);
|
const string worker_name = WorkerNameFromServerDef(server_def);
|
||||||
|
|
||||||
WorkerCacheInterface* worker_cache = nullptr;
|
WorkerCacheInterface* worker_cache = nullptr;
|
||||||
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
|
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
|
||||||
|
|
||||||
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
|
std::vector<Device*> renamed_devices;
|
||||||
new RpcRendezvousMgr(worker_env_, worker_name, worker_cache));
|
for (Device* d : worker_env_->local_devices) {
|
||||||
|
renamed_devices.push_back(
|
||||||
|
RenamedDevice::NewRenamedDevice(worker_name, d, false));
|
||||||
|
}
|
||||||
|
std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
|
||||||
|
|
||||||
std::unique_ptr<GraphMgr> graph_mgr(
|
std::unique_ptr<GraphMgr> graph_mgr(
|
||||||
new GraphMgr(worker_env_, rendezvous_mgr.get()));
|
new GraphMgr(worker_env_, device_mgr.get()));
|
||||||
|
|
||||||
std::unique_ptr<WorkerSession> worker_session(new WorkerSession(
|
std::unique_ptr<WorkerSession> worker_session(new WorkerSession(
|
||||||
worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
|
worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
|
||||||
std::move(rendezvous_mgr), std::move(graph_mgr)));
|
std::move(device_mgr), std::move(graph_mgr)));
|
||||||
|
|
||||||
sessions_.insert(std::make_pair(session, std::move(worker_session)));
|
sessions_.insert(std::make_pair(session, std::move(worker_session)));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) {
|
|||||||
if (it != sessions_.end()) {
|
if (it != sessions_.end()) {
|
||||||
sessions_.erase(it);
|
sessions_.erase(it);
|
||||||
}
|
}
|
||||||
std::set<string> graph_handles;
|
|
||||||
for (auto graph_handle_it = sessions_by_graph_handle_.begin();
|
|
||||||
graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) {
|
|
||||||
if (graph_handle_it->second == session) {
|
|
||||||
graph_handles.insert(graph_handle_it->first);
|
|
||||||
graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it);
|
|
||||||
if (graph_handle_it == sessions_by_graph_handle_.end()) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto step_id_it = graphs_by_step_id_.begin();
|
|
||||||
step_id_it != graphs_by_step_id_.end(); ++step_id_it) {
|
|
||||||
if (graph_handles.find(step_id_it->second) != graph_handles.end()) {
|
|
||||||
step_id_it = graphs_by_step_id_.erase(step_id_it);
|
|
||||||
if (step_id_it == graphs_by_step_id_.end()) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) {
|
|||||||
|
|
||||||
WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; }
|
WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; }
|
||||||
|
|
||||||
WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked(
|
|
||||||
const string& graph_handle) {
|
|
||||||
auto it = sessions_by_graph_handle_.find(graph_handle);
|
|
||||||
if (it == sessions_by_graph_handle_.end()) {
|
|
||||||
return &legacy_session_;
|
|
||||||
} else {
|
|
||||||
return WorkerSessionForSessionUnlocked(it->second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
WorkerSession* SessionMgr::WorkerSessionForGraphHandle(
|
|
||||||
const string& graph_handle) {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
return WorkerSessionForGraphHandleUnlocked(graph_handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
auto it = graphs_by_step_id_.find(step_id);
|
|
||||||
if (it == graphs_by_step_id_.end()) {
|
|
||||||
return &legacy_session_;
|
|
||||||
} else {
|
|
||||||
return WorkerSessionForGraphHandleUnlocked(it->second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void SessionMgr::AssociateGraphWithSession(const string& session,
|
|
||||||
const string& graph_handle) {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
sessions_by_graph_handle_[graph_handle] = session;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
auto it = sessions_by_graph_handle_.find(graph_handle);
|
|
||||||
if (it != sessions_by_graph_handle_.end()) {
|
|
||||||
sessions_by_graph_handle_.erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle,
|
|
||||||
const int64 step_id) {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
graphs_by_step_id_[step_id] = graph_handle;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
auto it = graphs_by_step_id_.find(step_id);
|
|
||||||
if (it != graphs_by_step_id_.end()) {
|
|
||||||
graphs_by_step_id_.erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -30,6 +30,8 @@ struct WorkerEnv;
|
|||||||
|
|
||||||
// SessionMgr keeps track of information related to a given session.
|
// SessionMgr keeps track of information related to a given session.
|
||||||
//
|
//
|
||||||
|
// SessionMgr runs on the workers.
|
||||||
|
//
|
||||||
// SessionMgr is threadsafe.
|
// SessionMgr is threadsafe.
|
||||||
class SessionMgr {
|
class SessionMgr {
|
||||||
public:
|
public:
|
||||||
@ -39,7 +41,6 @@ class SessionMgr {
|
|||||||
explicit SessionMgr(
|
explicit SessionMgr(
|
||||||
WorkerEnv* worker_env, const string& default_worker_name,
|
WorkerEnv* worker_env, const string& default_worker_name,
|
||||||
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
|
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
|
||||||
std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
|
|
||||||
WorkerCacheFactory worker_cache_factory);
|
WorkerCacheFactory worker_cache_factory);
|
||||||
~SessionMgr() {}
|
~SessionMgr() {}
|
||||||
|
|
||||||
@ -50,49 +51,36 @@ class SessionMgr {
|
|||||||
WorkerSession* WorkerSessionForSession(const string& session);
|
WorkerSession* WorkerSessionForSession(const string& session);
|
||||||
WorkerSession* LegacySession();
|
WorkerSession* LegacySession();
|
||||||
|
|
||||||
// Locates the worker session for a given graph handle
|
|
||||||
WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle);
|
|
||||||
void AssociateGraphWithSession(const string& session,
|
|
||||||
const string& graph_handle);
|
|
||||||
void DisassociateGraphFromSession(const string& graph_handle);
|
|
||||||
|
|
||||||
// Locates a worker session for a given step id
|
|
||||||
WorkerSession* WorkerSessionForStepId(const int64 step_id);
|
|
||||||
void AssociateStepIdWithGraph(const string& graph_handle,
|
|
||||||
const int64 step_id);
|
|
||||||
void DisassociateStepIdFromGraph(const int64 step_id);
|
|
||||||
|
|
||||||
Status DeleteSession(const string& session);
|
Status DeleteSession(const string& session);
|
||||||
|
|
||||||
static string WorkerNameFromServerDef(const ServerDef& server_def);
|
static string WorkerNameFromServerDef(const ServerDef& server_def);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Private constructor to work around std::unique_ptr ownership issues.
|
|
||||||
explicit SessionMgr(
|
|
||||||
WorkerEnv* worker_env, const string& default_worker_name,
|
|
||||||
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
|
|
||||||
RendezvousMgrInterface* default_rendezvous_mgr,
|
|
||||||
WorkerCacheFactory worker_cache_factory);
|
|
||||||
|
|
||||||
const WorkerEnv* const worker_env_; // Not owned.
|
const WorkerEnv* const worker_env_; // Not owned.
|
||||||
|
|
||||||
|
// A note about destruction:
|
||||||
|
// We must delete graph_mgr before device_mgr, due to shared
|
||||||
|
// ownership of OpKernels in the executors. (The graph_mgr will
|
||||||
|
// free all stateless OpKernels, and pass over borrowed stateful
|
||||||
|
// OpKernels, which are also held in their respective devices'
|
||||||
|
// OpSegments.)
|
||||||
|
//
|
||||||
|
// legacy_session_ owns the worker_env_.device_mgr, and so we must ensure
|
||||||
|
// that sessions_'s WorkerSessions are deleted (which do not own the
|
||||||
|
// underlying devices, but instead own RenamedDevices) before
|
||||||
|
// legacy_session_ is deleted. Further, we must ensure that WorkerSession's
|
||||||
|
// device_mgr is deleted after WorkerSession's graph_mgr.
|
||||||
|
|
||||||
WorkerSession legacy_session_;
|
WorkerSession legacy_session_;
|
||||||
|
|
||||||
const WorkerCacheFactory worker_cache_factory_;
|
const WorkerCacheFactory worker_cache_factory_;
|
||||||
|
|
||||||
WorkerSession* WorkerSessionForSessionUnlocked(const string& session)
|
WorkerSession* WorkerSessionForSessionUnlocked(const string& session)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle)
|
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
// A map from session identifier to internal session structure.
|
// A map from session identifier to internal session structure.
|
||||||
std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_);
|
std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
// A map from graph handles to the session that they belong to.
|
|
||||||
std::map<string, string> sessions_by_graph_handle_ GUARDED_BY(mu_);
|
|
||||||
|
|
||||||
// A map from globally-unique step id's to the corresponding graph handles.
|
|
||||||
std::map<int64, string> graphs_by_step_id_ GUARDED_BY(mu_);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test {
|
|||||||
SessionMgrTest()
|
SessionMgrTest()
|
||||||
: mgr_(&env_, "/job:mnist/replica:0/task:0",
|
: mgr_(&env_, "/job:mnist/replica:0/task:0",
|
||||||
std::unique_ptr<WorkerCacheInterface>(),
|
std::unique_ptr<WorkerCacheInterface>(),
|
||||||
std::unique_ptr<RendezvousMgrInterface>(new RpcRendezvousMgr(
|
|
||||||
&env_, "/job:mnist/replica:0/task:0", nullptr)),
|
|
||||||
factory_),
|
factory_),
|
||||||
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
|
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
|
||||||
|
|
||||||
@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
|
|||||||
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
|
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
|
||||||
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
|
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
|
||||||
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
|
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
|
||||||
|
EXPECT_NE(mgr_.LegacySession(), session);
|
||||||
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
|
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SessionMgrTest, AssociateGraphWithSession) {
|
TEST_F(SessionMgrTest, LegacySession) {
|
||||||
ServerDef server_def;
|
ServerDef server_def;
|
||||||
string session_handle = "test_session_handle";
|
string session_handle = "";
|
||||||
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
|
|
||||||
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
|
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
|
||||||
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
|
EXPECT_EQ(mgr_.LegacySession(), session);
|
||||||
|
|
||||||
string graph_handle = "test_graph_handle";
|
|
||||||
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
|
|
||||||
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
|
|
||||||
ASSERT_EQ(session, graph_session);
|
|
||||||
|
|
||||||
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
|
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SessionMgrTest, AssociateStepWithGraph) {
|
|
||||||
ServerDef server_def;
|
|
||||||
string session_handle = "test_session_handle";
|
|
||||||
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
|
|
||||||
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
|
|
||||||
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
|
|
||||||
|
|
||||||
string graph_handle = "test_graph_handle";
|
|
||||||
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
|
|
||||||
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
|
|
||||||
ASSERT_EQ(session, graph_session);
|
|
||||||
|
|
||||||
int64 step_id = 1234567890L;
|
|
||||||
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
|
|
||||||
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
|
|
||||||
ASSERT_EQ(session, step_session);
|
|
||||||
ASSERT_EQ(graph_session, step_session);
|
|
||||||
|
|
||||||
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) {
|
|
||||||
string session_handle = "test_session_handle";
|
|
||||||
string graph_handle = "test_graph_handle";
|
|
||||||
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
|
|
||||||
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
|
|
||||||
ASSERT_EQ(legacy_session_, graph_session);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) {
|
|
||||||
ServerDef server_def;
|
|
||||||
string session_handle = "test_session_handle";
|
|
||||||
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
|
|
||||||
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
|
|
||||||
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
|
|
||||||
|
|
||||||
string graph_handle = "test_graph_handle";
|
|
||||||
int64 step_id = 1234567890L;
|
|
||||||
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
|
|
||||||
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
|
|
||||||
ASSERT_EQ(legacy_session_, step_session);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) {
|
|
||||||
string session_handle = "test_session_handle";
|
|
||||||
string graph_handle = "test_graph_handle";
|
|
||||||
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
|
|
||||||
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
|
|
||||||
ASSERT_EQ(legacy_session_, graph_session);
|
|
||||||
|
|
||||||
int64 step_id = 1234567890L;
|
|
||||||
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
|
|
||||||
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
|
|
||||||
ASSERT_EQ(legacy_session_, step_session);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) {
|
|
||||||
string session_handle = "test_session_handle";
|
|
||||||
string graph_handle = "test_graph_handle";
|
|
||||||
int64 step_id = 1234567890L;
|
|
||||||
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
|
|
||||||
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
|
|
||||||
ASSERT_EQ(legacy_session_, step_session);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
|
TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
|
||||||
ServerDef server_def;
|
ServerDef server_def;
|
||||||
server_def.set_job_name("worker");
|
server_def.set_job_name("worker");
|
||||||
|
@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
|
|||||||
Status s = session->graph_mgr->Register(
|
Status s = session->graph_mgr->Register(
|
||||||
request->session_handle(), request->graph_def(), request->graph_options(),
|
request->session_handle(), request->graph_def(), request->graph_options(),
|
||||||
request->debug_options(), response->mutable_graph_handle());
|
request->debug_options(), response->mutable_graph_handle());
|
||||||
if (s.ok()) {
|
|
||||||
env_->session_mgr->AssociateGraphWithSession(request->session_handle(),
|
|
||||||
response->graph_handle());
|
|
||||||
}
|
|
||||||
done(s);
|
done(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
|
|||||||
DeregisterGraphResponse* response,
|
DeregisterGraphResponse* response,
|
||||||
StatusCallback done) {
|
StatusCallback done) {
|
||||||
WorkerSession* session =
|
WorkerSession* session =
|
||||||
env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle());
|
env_->session_mgr->WorkerSessionForSession(request->session_handle());
|
||||||
Status s = session->graph_mgr->Deregister(request->graph_handle());
|
Status s = session->graph_mgr->Deregister(request->graph_handle());
|
||||||
env_->session_mgr->DisassociateGraphFromSession(request->graph_handle());
|
|
||||||
|
|
||||||
done(s);
|
done(s);
|
||||||
}
|
}
|
||||||
@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Worker::AbortStep(int64 step_id) {
|
void Worker::AbortStep(int64 step_id) {
|
||||||
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
|
Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
|
||||||
Rendezvous* rendez = session->rendezvous_mgr->Find(step_id);
|
|
||||||
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
|
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
|
||||||
// Delay a bit before aborting the step. This way, the root
|
// Delay a bit before aborting the step. This way, the root
|
||||||
// cause may return first back to the client instead of this
|
// cause may return first back to the client instead of this
|
||||||
@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
|
|||||||
const int64 step_id = request->step_id();
|
const int64 step_id = request->step_id();
|
||||||
TRACEPRINTF("RunGraph: %lld", step_id);
|
TRACEPRINTF("RunGraph: %lld", step_id);
|
||||||
WorkerSession* session =
|
WorkerSession* session =
|
||||||
env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle());
|
env_->session_mgr->WorkerSessionForSession(request->session_handle());
|
||||||
env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id);
|
|
||||||
GraphMgr::NamedTensors in;
|
GraphMgr::NamedTensors in;
|
||||||
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
|
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
|
||||||
Status s = PrepareRunGraph(request, &in, out);
|
Status s = PrepareRunGraph(request, &in, out);
|
||||||
@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
|
|||||||
}
|
}
|
||||||
CostGraphDef* cost_graph = response->mutable_cost_graph();
|
CostGraphDef* cost_graph = response->mutable_cost_graph();
|
||||||
session->graph_mgr->ExecuteAsync(
|
session->graph_mgr->ExecuteAsync(
|
||||||
request->graph_handle(), step_id, request->exec_opts(), collector,
|
request->graph_handle(), step_id, session, request->exec_opts(),
|
||||||
cost_graph, cm, in,
|
collector, cost_graph, cm, in,
|
||||||
[this, step_id, response, session, cm, out, token, collector, opts,
|
[this, step_id, response, session, cm, out, token, collector, opts,
|
||||||
done](Status s) {
|
done](Status s) {
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
|
|||||||
const string& graph_handle = request->graph_handle();
|
const string& graph_handle = request->graph_handle();
|
||||||
TRACEPRINTF("PartialRunGraph: %lld", step_id);
|
TRACEPRINTF("PartialRunGraph: %lld", step_id);
|
||||||
WorkerSession* session =
|
WorkerSession* session =
|
||||||
env_->session_mgr->WorkerSessionForGraphHandle(graph_handle);
|
env_->session_mgr->WorkerSessionForSession(request->session_handle());
|
||||||
env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id);
|
|
||||||
GraphMgr::NamedTensors in;
|
GraphMgr::NamedTensors in;
|
||||||
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
|
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
|
||||||
Status s = PrepareRunGraph(request, &in, out);
|
Status s = PrepareRunGraph(request, &in, out);
|
||||||
@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
|
|||||||
[cm]() { cm->StartCancel(); });
|
[cm]() { cm->StartCancel(); });
|
||||||
}
|
}
|
||||||
session->graph_mgr->ExecuteAsync(
|
session->graph_mgr->ExecuteAsync(
|
||||||
graph_handle, step_id, request->exec_opts(), nullptr /* collector */,
|
graph_handle, step_id, session, request->exec_opts(),
|
||||||
nullptr /* cost_graph */, cm, in,
|
nullptr /* collector */, nullptr /* cost_graph */, cm, in,
|
||||||
[this, token, graph_handle, step_id, cm](Status s) {
|
[this, token, graph_handle, step_id, cm](Status s) {
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
|
|||||||
CleanupGraphResponse* response,
|
CleanupGraphResponse* response,
|
||||||
StatusCallback done) {
|
StatusCallback done) {
|
||||||
const int64 step_id = request->step_id();
|
const int64 step_id = request->step_id();
|
||||||
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
|
env_->rendezvous_mgr->Cleanup(step_id);
|
||||||
session->rendezvous_mgr->Cleanup(step_id);
|
|
||||||
done(Status::OK());
|
done(Status::OK());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request,
|
|||||||
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
|
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
|
||||||
Device** src_dev) {
|
Device** src_dev) {
|
||||||
// Figures out which device the tensor is hosted on.
|
// Figures out which device the tensor is hosted on.
|
||||||
TF_RETURN_IF_ERROR(
|
string local_name = DeviceNameUtils::LocalName(parsed.src_device);
|
||||||
env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
|
TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
|
||||||
|
|
||||||
// Does the device have the right incarnation number we expect?
|
// Does the device have the right incarnation number we expect?
|
||||||
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
|
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
|
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
|
||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -24,8 +25,10 @@ namespace thread {
|
|||||||
class ThreadPool;
|
class ThreadPool;
|
||||||
} // namespace thread
|
} // namespace thread
|
||||||
|
|
||||||
|
class Device;
|
||||||
class DeviceMgr;
|
class DeviceMgr;
|
||||||
class Env;
|
class Env;
|
||||||
|
class RendezvousMgrInterface;
|
||||||
class SessionMgr;
|
class SessionMgr;
|
||||||
|
|
||||||
// The worker environment class, which holds a bag of pointers to
|
// The worker environment class, which holds a bag of pointers to
|
||||||
@ -38,10 +41,18 @@ struct WorkerEnv {
|
|||||||
// session_mgr encapsulates state for each session.
|
// session_mgr encapsulates state for each session.
|
||||||
SessionMgr* session_mgr = nullptr;
|
SessionMgr* session_mgr = nullptr;
|
||||||
|
|
||||||
|
// The local devices of this worker. Devices are owned by the device_mgr.
|
||||||
|
//
|
||||||
|
// REQUIRES: !local_devices.empty().
|
||||||
|
std::vector<Device*> local_devices;
|
||||||
|
|
||||||
// device_mgr manages local devices (cpu and gpu). The WorkerService
|
// device_mgr manages local devices (cpu and gpu). The WorkerService
|
||||||
// is the network interface for managed devices.
|
// is the network interface for managed devices.
|
||||||
DeviceMgr* device_mgr = nullptr;
|
DeviceMgr* device_mgr = nullptr;
|
||||||
|
|
||||||
|
// A set of rendezvous keyed by step ids.
|
||||||
|
RendezvousMgrInterface* rendezvous_mgr = nullptr;
|
||||||
|
|
||||||
// A pool of threads for scheduling compute work.
|
// A pool of threads for scheduling compute work.
|
||||||
thread::ThreadPool* compute_pool = nullptr;
|
thread::ThreadPool* compute_pool = nullptr;
|
||||||
};
|
};
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user