diff --git a/configure b/configure index 4104651cbbb..4e66e952c2b 100755 --- a/configure +++ b/configure @@ -385,7 +385,7 @@ fi # Append CC optimization flags to bazel.rc for opt in $CC_OPT_FLAGS; do - write_to_bazelrc 'build:opt --cxxopt=$opt --copt=$opt' + write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt" done # Run the gen_git_source to create links where bazel can track dependencies for diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 4ad69ae3fbd..3ab4e8efcdb 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -58,6 +58,7 @@ tf_cuda_library( "//tensorflow/cc/saved_model:loader", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", + "//tensorflow/cc:grad_ops", "//tensorflow/cc:scope_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 8810b8731ae..8d4260a0b9c 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -91,6 +91,7 @@ cc_library( deps = [ ":array_grad", ":math_grad", + ":nn_grad", ], ) diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 93f487c36ca..5e336c5287b 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options, const DeviceType& jit_device_name, perftools::gputools::Platform* platform, Allocator* xla_allocator) - : LocalDevice(options, attrs, xla_allocator), + : LocalDevice(options, attrs), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(xla_allocator), diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d86e741b69e..362a1018955 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, options, 
Device::BuildDeviceAttributes( "", type, Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type())), - cpu_allocator()), + strings::StrCat("device: XLA compilation device ", type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 87ceb43d1fe..6af69eeec12 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -668,6 +668,14 @@ class ComputationBuilder { // then Build() should be used instead. Computation BuildAndNoteError(); + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // ComputationDataHandle and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). 
+ Status first_error() const { return first_error_; } + private: using PopulateLiteral = std::function; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 1c704fd1ee7..1e34de9e4bd 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -201,7 +201,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name, if (&argument == retval) { continue; } - compute_function_->setDoesNotAlias(argument.getArgNo() + 1); + compute_function_->addAttribute(argument.getArgNo() + 1, + llvm::Attribute::NoAlias); } ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 04babcca0c8..e52e55a1a81 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -196,7 +196,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ir_emitter_context_->buffer_assignment().GetTempAllocation()) { kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size()); } - kernel->setDoesNotAlias(temp_buffer_arg_no + 1); + kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias); // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX // treats it as a CUDA kernel. 
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 5e7bd4a7ce8..d413621cfe2 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -705,7 +705,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape()) && ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && + if ((instruction->IsElementwiseOnOperand(operand_no) || + InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) && !ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape())) { diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 61dc7b12075..4f586c334dc 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -248,6 +248,15 @@ class LayoutAssignment : public HloPassInterface { return Status::OK(); } + // This method can be overridden to mark instructions as requiring the operands + // to have the same layout as the result, for performance or correctness. This + // will propagate constraints through the instruction from the result into the + // operands. + virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction) { + return false; + } + // Construct contraints and assign layouts to all instructions in the // computation satisfying the given ComputationLayout. 
Layouts constraints are // added, then propagated until all LogicalBuffers in the computation are diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 338d63f1a00..b2ef8ed486b 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -244,8 +244,11 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "cannot concatenate arrays with different ranks: %lld vs %lld", - ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape)); + "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "(%s)", + ShapeUtil::Rank(*arg_shape), + ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), + ShapeUtil::HumanString(*shape).c_str()); } if (arg_shape->element_type() != shape->element_type()) { return InvalidArgument( diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 6fd1ae08149..560e45fc135 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -118,6 +118,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/types.proto" "tensorflow/core/framework/versions.proto" "tensorflow/core/lib/core/error_codes.proto" + "tensorflow/core/protobuf/cluster.proto" "tensorflow/core/protobuf/config.proto" "tensorflow/core/protobuf/debug.proto" "tensorflow/core/protobuf/rewriter_config.proto" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 2a19433a7b2..eae00ab8756 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -22,6 +22,7 @@ set(tf_op_lib_names "image_ops" "io_ops" "linalg_ops" + "lookup_ops" "logging_ops" "math_ops" "nn_ops" diff --git a/tensorflow/contrib/cmake/tf_python.cmake 
b/tensorflow/contrib/cmake/tf_python.cmake index 53ebfbb57de..9e2eb71b4c2 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -203,6 +203,7 @@ add_python_module("tensorflow/python/estimator") add_python_module("tensorflow/python/estimator/export") add_python_module("tensorflow/python/estimator/inputs") add_python_module("tensorflow/python/estimator/inputs/queues") +add_python_module("tensorflow/python/feature_column") add_python_module("tensorflow/python/framework") add_python_module("tensorflow/python/grappler") add_python_module("tensorflow/python/kernel_tests") @@ -596,6 +597,7 @@ GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") GENERATE_PYTHON_OP_LIB("logging_ops") +GENERATE_PYTHON_OP_LIB("lookup_ops") GENERATE_PYTHON_OP_LIB("nn_ops") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 9f675c66135..0c818dee031 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -710,25 +710,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "identity_test", - size = "small", - srcs = ["python/kernel_tests/bijectors/identity_test.py"], - additional_deps = [ - ":bijectors_py", - ":distributions_py", - "//third_party/py/numpy", - "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - cuda_py_test( name = "inline_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 6ea74fab0e4..ea12e13010a 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ 
b/tensorflow/contrib/distributions/__init__.py @@ -25,6 +25,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.chi2 import * +from tensorflow.contrib.distributions.python.ops.conditional_distribution import * from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform @@ -44,12 +45,10 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * from tensorflow.contrib.distributions.python.ops.sample_stats import * -from tensorflow.contrib.distributions.python.ops.transformed_distribution import * from tensorflow.contrib.distributions.python.ops.wishart import * from tensorflow.python.ops.distributions.bernoulli import * from tensorflow.python.ops.distributions.beta import * from tensorflow.python.ops.distributions.categorical import * -from tensorflow.python.ops.distributions.conditional_distribution import * from tensorflow.python.ops.distributions.dirichlet import * from tensorflow.python.ops.distributions.dirichlet_multinomial import * from tensorflow.python.ops.distributions.distribution import * @@ -60,6 +59,7 @@ from tensorflow.python.ops.distributions.laplace import * from tensorflow.python.ops.distributions.multinomial import * from tensorflow.python.ops.distributions.normal import * from tensorflow.python.ops.distributions.student_t import * +from tensorflow.python.ops.distributions.transformed_distribution import * from tensorflow.python.ops.distributions.uniform import * # pylint: 
enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index 13554f76642..e8fd6aa2f73 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -23,9 +23,9 @@ import itertools import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index 994e21dd487..20e75430844 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -20,12 +20,12 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform 
import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index a4688829f1f..0ff35304283 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import gamma as gamma_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py index c30ce60cacc..9970c0b4d86 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp +from 
tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py index a4688829f1f..0ff35304283 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import gamma as gamma_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py index b30a3b599bb..de1659aa9f4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite -from 
tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import PowerTransform +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py index 6f1a6b1cf4b..e4f9d72785c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import numpy as np from scipy import special -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py index 173d52686d6..62e3869db09 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from 
tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py index 214b196b547..d9af9aec50d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test rng = np.random.RandomState(42) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index e1d31e373cc..1684a5fffe1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -43,7 +43,6 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from 
tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * -from tensorflow.contrib.distributions.python.ops.bijectors.identity import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * @@ -52,6 +51,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered impo from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * from tensorflow.python.ops.distributions.bijector import * +from tensorflow.python.ops.distributions.identity_bijector import Identity # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/identity.py b/tensorflow/contrib/distributions/python/ops/bijectors/identity.py deleted file mode 100644 index 749dd268f98..00000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/identity.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Identity bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.identity_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ["Identity"] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/ops/distributions/conditional_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_distribution.py similarity index 100% rename from tensorflow/python/ops/distributions/conditional_distribution.py rename to tensorflow/contrib/distributions/python/ops/conditional_distribution.py diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index b0967802bd8..2e1e68cf058 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -17,9 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops import transformed_distribution +from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import conditional_distribution +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 
a66eb1674ca..fbd623ed3a1 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -29,6 +28,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 581e190f73b..5b57a95c55e 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import logistic -from tensorflow.contrib.distributions.python.ops import transformed_distribution # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. 
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid @@ -27,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 00415f5e1aa..da1cd72a6f1 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -30,6 +29,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 299ff36962e..ae804b61727 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors -from 
tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import student_t +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index aba8eabe10c..fe661a56250 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -108,6 +108,7 @@ tf_custom_op_py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/feature_column", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index d6d5bf2294f..04fe2370d1d 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -136,8 +136,10 @@ from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.ops import bucketization_op from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as sparse_tensor_py +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -1497,9 +1499,12 @@ def _real_valued_var_len_column(column_name, is_sparse) -class 
_RealValuedColumn(_FeatureColumn, collections.namedtuple( - "_RealValuedColumn", - ["column_name", "dimension", "default_value", "dtype", "normalizer"])): +class _RealValuedColumn( + _FeatureColumn, + fc_core._DenseColumn, # pylint: disable=protected-access + collections.namedtuple( + "_RealValuedColumn", + ["column_name", "dimension", "default_value", "dtype", "normalizer"])): """Represents a real valued feature column also known as continuous features. Instances of this class are immutable. The dictionary returned by InputBuilder @@ -1569,6 +1574,23 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple( def _to_dense_tensor(self, input_tensor): return input_tensor + @property + def _variable_shape(self): + return tensor_shape.TensorShape((self.dimension)) + + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + del weight_collections + del trainable + return inputs.get(self) + + def _transform_feature(self, inputs): + return math_ops.to_float( + self._normalized_input_tensor(inputs.get(self.name))) + + @property + def _parse_example_config(self): + return self.config + def real_valued_column(column_name, dimension=1, diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 632836fee44..a09cc53571b 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -27,14 +27,15 @@ from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from 
tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope @@ -223,7 +224,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(keys_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[keys_sparse].indices.eval(), @@ -241,7 +242,7 @@ class TransformerTest(test.TestCase): output = feature_column_ops._Transformer(features).transform(keys_sparse) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # While the input is a dense Tensor, the output should be a SparseTensor. 
self.assertIsInstance(output, sparse_tensor.SparseTensor) self.assertEqual(output.dtype, dtypes.int64) @@ -310,7 +311,7 @@ class TransformerTest(test.TestCase): self.assertIn(weighted_ids, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(), ids_tensor.dense_shape.eval()) self.assertAllEqual(output[weighted_ids][0].indices.eval(), @@ -340,7 +341,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -362,7 +363,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -386,7 +387,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -408,7 +409,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() 
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -600,7 +601,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): one_hot_column, embedding_column, real_valued_column]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10]) def testRealValuedColumn(self): @@ -610,6 +611,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval()) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnWithMultiDimensions(self): real_valued = feature_column.real_valued_column("price", 2) @@ -620,6 +625,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval()) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnSparse(self): sparse_real_valued = feature_column._real_valued_var_len_column( @@ -640,6 +649,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval() - 2) + # Verify cross compatibility: Core builder output should equal to contrib. 
+ self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnWithMultiDimensionsAndNormalizer(self): real_valued = feature_column.real_valued_column( @@ -651,6 +664,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval() - 2) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testBucketizedColumnWithNormalizerSucceedsForDNN(self): bucket = feature_column.bucketized_column( @@ -697,7 +714,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_column]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]], output.eval()) @@ -715,7 +732,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]], output.eval()) @@ -733,7 +750,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], output.eval()) @@ -767,7 +784,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([3, 10], output.eval().shape) def testEmbeddingColumnSucceedsForDNN(self): @@ -874,7 +891,7 @@ class 
CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self): @@ -897,7 +914,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self): @@ -948,7 +965,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "Error creating input layer for column: ids_weighted_by_weights"): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() feature_column_ops.input_from_feature_columns(features, [weighted_ids]) def testCrossedColumnFailsForDNN(self): @@ -1055,7 +1072,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # score: (sum of weights) self.assertAllEqual(output.eval(), [[10.], [50.], [0.]]) @@ -1293,7 +1310,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = np.array([4, 3, 4]) @@ -1327,7 +1344,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() 
model_input = sess.run(model_input_tensor) expected_input_shape = np.array([4, 3, hash_buckets]) @@ -1357,7 +1374,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) self.assertAllEqual(expected_input_shape, model_input.shape) @@ -1386,7 +1403,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) self.assertAllEqual(expected_input_shape, model_input.shape) @@ -1416,7 +1433,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): embedding_weights) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input, gradients = sess.run([model_input_tensor, gradient_tensor]) expected_input_shape = [4, 3, embedding_dimension] @@ -1483,7 +1500,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = [ @@ -1564,7 +1581,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_ids], num_outputs=5) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) def testWeightedSparseColumnWithDenseInputTensor(self): @@ -1580,7 +1597,7 @@ class WeightedSumTest(test.TestCase): with self.test_session(): 
variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) def testCrossedColumn(self): @@ -1634,7 +1651,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) with self.test_session() as sess: variables_lib.initialize_all_variables().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] self.assertEqual(weights.get_shape(), (3, 1)) @@ -1709,7 +1726,7 @@ class WeightedSumTest(test.TestCase): features, [age, language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1749,7 +1766,7 @@ class WeightedSumTest(test.TestCase): self.assertEqual(len(variables), 1) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1813,7 +1830,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1841,7 +1858,7 @@ class WeightedSumTest(test.TestCase): features, [language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # score: 0.1 + language_weight['hindi'] + language_weight['english'] sess.run(bias.assign([0.1])) @@ -1864,7 +1881,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) with self.test_session() 
as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] self.assertEqual(weights.get_shape(), (15, 1)) @@ -1898,7 +1915,7 @@ class WeightedSumTest(test.TestCase): features, [country_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1922,7 +1939,7 @@ class WeightedSumTest(test.TestCase): features, [language_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[language_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1955,7 +1972,7 @@ class WeightedSumTest(test.TestCase): features, [country_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1996,7 +2013,7 @@ class WeightedSumTest(test.TestCase): scope=scope)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(2, len(column_to_variable[country])) self.assertEqual(3, len(column_to_variable[language])) @@ -2033,7 +2050,7 @@ class WeightedSumTest(test.TestCase): features, [country, age, incomes], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() incomes_weights = 
column_to_variable[incomes][0] sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]])) @@ -2069,7 +2086,7 @@ class WeightedSumTest(test.TestCase): features, [country, age, height, incomes], num_outputs=5)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() height_weights = column_to_variable[height][0] sess.run( @@ -2099,7 +2116,7 @@ class WeightedSumTest(test.TestCase): features, [bucket], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3], [0.4]])) @@ -2127,7 +2144,7 @@ class WeightedSumTest(test.TestCase): features, [bucket, country], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # dimension = 2, bucket_size = 4, num_classes = 1 sess.run(column_to_variable[bucket][0].assign( @@ -2156,7 +2173,7 @@ class WeightedSumTest(test.TestCase): features, [bucket, country], num_outputs=5)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # dimension = 2, bucket_size = 4, num_classes = 5 sess.run(column_to_variable[bucket][0].assign( @@ -2192,7 +2209,7 @@ class WeightedSumTest(test.TestCase): features, [country_price], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_price][0] sess.run(weights.assign(weights + 0.4)) @@ -2231,7 +2248,7 @@ class WeightedSumTest(test.TestCase): features, [country_language_price], num_outputs=1)) with 
self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language_price][0] sess.run(weights.assign(weights + 0.4)) @@ -2255,7 +2272,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2270,7 +2287,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2285,7 +2302,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.6], [0.7]]) @@ -2306,7 +2323,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), 
[[0.1], [0.5], [0.3]]) @@ -2318,7 +2335,7 @@ class WeightedSumTest(test.TestCase): features, [feature_column.real_valued_column("age")], num_outputs=3) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() sess.run(bias.assign([0.1, 0.2, 0.3])) self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) @@ -2332,7 +2349,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (1, 3)) sess.run(weights.assign([[0.01, 0.03, 0.05]])) @@ -2356,7 +2373,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) sess.run( @@ -2382,7 +2399,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2422,7 +2439,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2451,7 +2468,7 @@ class WeightedSumTest(test.TestCase): features, 
[column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2516,7 +2533,7 @@ class ParseExampleTest(test.TestCase): self.assertIn(bucket, output) self.assertIn(wire_cast, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]]) self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]]) self.assertAllEqual(output[wire_cast].values.eval(), [2, 0]) diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index 9fb9a3e2571..1926cbe7b31 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -46,7 +46,7 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32): Args: uniform: Whether to use uniform or normal distributed random initialization. seed: A Python integer. Used to create random seeds. See - @{set_random_seed} for behavior. + @{tf.set_random_seed} for behavior. dtype: The data type. Only floating point types are supported. Returns: @@ -96,7 +96,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'. uniform: Whether to use uniform or normal distributed random initialization. seed: A Python integer. Used to create random seeds. See - @{set_random_seed} for behavior. + @{tf.set_random_seed} for behavior. dtype: The data type. Only floating point types are supported. 
Returns: diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 61a6168a9eb..6fc028ab706 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -157,7 +157,7 @@ class DynamicRnnEstimatorTest(test.TestCase): self.context_feature_columns) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) sequence_input_val = sess.run(sequence_input) expected_shape = np.array([ 3, # expected batch size @@ -178,7 +178,7 @@ class DynamicRnnEstimatorTest(test.TestCase): # Obtain values of activations and final state. 
with session.Session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) activations, final_state = sess.run([activations_t, final_state_t]) expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 74a6da20d4e..36f843ba8e7 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -57,7 +57,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -1292,7 +1292,7 @@ class Estimator(BaseEstimator): init_op = control_flow_ops.group( variables.local_variables_initializer(), resources.initialize_resources(resources.shared_resources()), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) # Perform the export builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 207a189a94d..d5777088de7 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -32,7 +32,7 @@ from tensorflow.core.framework import summary_pb2 from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import data_flow_ops +from 
tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses as losses_lib from tensorflow.python.platform import test @@ -1214,7 +1214,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.), (0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( [0, 2], model_fn_ops.predictions["classes"].eval()) @@ -1266,7 +1266,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.), (0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( [b"key0", b"key2"], model_fn_ops.predictions["classes"].eval()) @@ -1301,7 +1301,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) _assert_summary_tags(self, ["loss"]) @@ -1327,7 +1327,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) _assert_summary_tags(self, ["loss"]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index e7470a544f0..69469b577dd 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -35,8 +35,8 @@ from tensorflow.python.framework import constant_op from 
tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -55,7 +55,7 @@ class PrepareInputsForRnnTest(test.TestCase): with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.initialize_all_tables()) + sess.run(lookup_ops.tables_initializer()) features_val = sess.run(features_by_time) self.assertAllEqual(expected, features_val) @@ -316,7 +316,7 @@ class StateSavingRnnEstimatorTest(test.TestCase): with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.initialize_all_tables()) + sess.run(lookup_ops.tables_initializer()) actual_sequence, actual_context = sess.run( [sequence, context]) assert_equal(expected_sequence, actual_sequence) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 602d33e5f9b..85d45aef7ac 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -647,6 +647,10 @@ class Experiment(object): if _sentinel is not None: raise ValueError("_call_train should be called with keyword args only") + # Estimator in core cannot work with monitors. We need to convert them + # to hooks. For Estimator in contrib, it is converted internally. So, it is + # safe to convert for both cases. 
+ hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator) if self._core_estimator_used: return self._estimator.train(input_fn=input_fn, steps=steps, diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index 4b5f3a195ce..9ecfc732998 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -24,7 +24,6 @@ import time from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import experiment -from tensorflow.contrib.learn.python.learn import monitors from tensorflow.contrib.learn.python.learn import run_config from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib @@ -461,7 +460,8 @@ class ExperimentTest(test.TestCase): self.assertEqual(1, est.eval_count) self.assertEqual(1, len(est.monitors)) self.assertEqual([noop_hook], est.eval_hooks) - self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor)) + self.assertTrue(isinstance(est.monitors[0], + session_run_hook.SessionRunHook)) def test_train_hooks_extend_does_not_mutate_input_hooks(self): for est in self._estimators_for_tests(): @@ -563,7 +563,8 @@ class ExperimentTest(test.TestCase): self.assertEqual(1, est.export_count) self.assertEqual(1, len(est.monitors)) self.assertEqual([noop_hook], est.eval_hooks) - self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor)) + self.assertTrue(isinstance(est.monitors[0], + session_run_hook.SessionRunHook)) def test_train_and_evaluate_with_no_eval_during_training(self): for est in self._estimators_for_tests(): diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 4b7867f2d00..98365c05f66 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ 
b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -37,8 +37,8 @@ from tensorflow.python.client import session as tf_session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging @@ -429,11 +429,14 @@ def _get_ready_op(): def _get_local_init_op(): + """Returns the local init ops to initialize tables and local variables.""" local_init_op = _get_first_op_from_collection( ops.GraphKeys.LOCAL_INIT_OP) if local_init_op is None: - op_list = [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()] + op_list = [ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ] if op_list: local_init_op = control_flow_ops.group(*op_list) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) @@ -680,7 +683,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None): else: session.run(variables.global_variables_initializer()) session.run(variables.local_variables_initializer()) - session.run(data_flow_ops.tables_initializer()) + session.run(lookup_ops.tables_initializer()) coord = coordinator.Coordinator() threads = None try: diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index b53be292830..36a1f5f60cd 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops 
import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver @@ -67,17 +67,17 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, with graph.as_default(): with tf_session.Session('') as session: variables.local_variables_initializer() - data_flow_ops.tables_initializer() + lookup_ops.tables_initializer() saver.restore(session, checkpoint_path) export = exporter.Exporter(saver) - export.init(init_op=control_flow_ops.group( - variables.local_variables_initializer(), - data_flow_ops.tables_initializer()), - default_graph_signature=default_graph_signature, - named_graph_signatures=named_graph_signatures, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS)) + export.init( + init_op=control_flow_ops.group( + variables.local_variables_initializer(), + lookup_ops.tables_initializer()), + default_graph_signature=default_graph_signature, + named_graph_signatures=named_graph_signatures, + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) return export.export(export_dir, contrib_variables.get_global_step(), session, exports_to_keep=exports_to_keep) diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index b3316ee8c4f..bbbd3403526 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -13,19 +13,10 @@ py_library( name = "lookup_py", srcs = [ "__init__.py", - "lookup_ops.py", ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops_gen", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:string_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", + "//tensorflow/python/feature_column:lookup_ops", ], 
) @@ -39,11 +30,11 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:data_flow_ops", "//tensorflow/python:errors", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lookup_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py index dbd64cf0421..a5fcdc7b42d 100644 --- a/tensorflow/contrib/lookup/__init__.py +++ b/tensorflow/contrib/lookup/__init__.py @@ -47,7 +47,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import -from tensorflow.contrib.lookup.lookup_ops import * +from tensorflow.python.feature_column.lookup_ops import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 0ec40a63f26..5ec169b6db4 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver @@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase): table3 = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(3, 
table1.size().eval()) self.assertAllEqual(3, table2.size().eval()) self.assertAllEqual(3, table3.size().eval()) @@ -1184,7 +1184,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int32_index_table_from_file(self): @@ -1198,7 +1198,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_file(self): @@ -1212,7 +1212,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_index_table_from_file_with_default_value(self): @@ -1224,7 +1224,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_file_with_oov_buckets(self): @@ -1236,7 +1236,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( ( 1, # From vocabulary file. 
@@ -1259,7 +1259,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, -1, -1), ids.eval()) self.assertEqual(2, table.size().eval()) @@ -1286,7 +1286,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), ids.eval()) self.assertEqual(3, table.size().eval()) @@ -1345,7 +1345,7 @@ class IndexTableFromTensor(test.TestCase): ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int32_index_table_from_tensor_with_tensor_init(self): @@ -1356,7 +1356,7 @@ class IndexTableFromTensor(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_tensor_with_tensor_init(self): @@ -1367,7 +1367,7 @@ class IndexTableFromTensor(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_index_table_from_tensor_with_default_value(self): @@ -1378,7 +1378,7 @@ class IndexTableFromTensor(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - 
data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_tensor_missing_mapping(self): @@ -1394,7 +1394,7 @@ class IndexTableFromTensor(test.TestCase): self.assertRaises(errors_impl.OpError, ids.eval) with self.assertRaisesRegexp( errors_impl.OpError, "keys and values cannot be empty"): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() def test_index_table_from_tensor_with_invalid_hashers(self): with self.test_session(): @@ -1422,7 +1422,7 @@ class StringToIndexTest(test.TestCase): indices = lookup.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, indices.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), indices.eval()) @@ -1433,7 +1433,7 @@ class StringToIndexTest(test.TestCase): _ = lookup.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, - data_flow_ops.tables_initializer().run) + lookup_ops.tables_initializer().run) def test_string_to_index_with_default_value(self): default_value = -42 @@ -1444,7 +1444,7 @@ class StringToIndexTest(test.TestCase): feats, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, indices.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), indices.eval()) @@ -1463,7 +1463,7 @@ class IndexToStringTableFromFileTest(test.TestCase): vocabulary_file=vocabulary_file) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -1475,7 +1475,7 @@ class IndexToStringTableFromFileTest(test.TestCase): 
vocabulary_file=vocabulary_file, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -1489,7 +1489,7 @@ class IndexToStringTableFromFileTest(test.TestCase): default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", default_value, default_value), features.eval()) @@ -1501,7 +1501,7 @@ class IndexToStringTableFromFileTest(test.TestCase): features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - init = data_flow_ops.tables_initializer() + init = lookup_ops.tables_initializer() self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Invalid vocab_size", init.run) @@ -1513,7 +1513,7 @@ class IndexToStringTableFromFileTest(test.TestCase): features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval()) @@ -1528,7 +1528,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) features = table.lookup(indices) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -1540,7 +1540,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): mapping=mapping_strings) indices = constant_op.constant([0, 1, 4], 
dtypes.int64) features = table.lookup(indices) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval()) def test_index_to_string_with_default_value(self): @@ -1553,7 +1553,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): features = table.lookup(indices) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -1567,7 +1567,7 @@ class IndexToStringTest(test.TestCase): feats = lookup.index_to_string(indices, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, feats.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), feats.eval()) @@ -1577,11 +1577,11 @@ class IndexToStringTest(test.TestCase): mapping_strings = constant_op.constant(["hello", "hello"]) indices = constant_op.constant([0, 1, 4], dtypes.int64) feats = lookup.index_to_string(indices, mapping=mapping_strings) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval()) self.assertRaises(errors_impl.OpError, - data_flow_ops.tables_initializer().run) + lookup_ops.tables_initializer().run) def test_index_to_string_with_default_value(self): default_value = b"NONE" @@ -1592,7 +1592,7 @@ class IndexToStringTest(test.TestCase): indices, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, feats.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval()) @@ -1755,7 +1755,7 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value, shared_name=shared_name) - data_flow_ops.tables_initializer().run() + 
lookup_ops.tables_initializer().run() input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -2081,7 +2081,7 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup.StrongHashSpec((1, 2)), name="table2") - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() input_string = constant_op.constant( ["fruit", "brain", "salad", "surgery", "UNK"]) @@ -2167,7 +2167,7 @@ class IdTableWithHashBucketsTest(test.TestCase): default_value2), oov_buckets) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() input_string_1 = constant_op.constant( ["brain", "salad", "surgery", "UNK"]) diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index c0969e6dee2..2f1fcb149e1 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/meta_graph.pb.cc +tensorflow/core/protobuf/cluster.pb.cc tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/debug.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 132b4775962..6087a45168d 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/meta_graph.pb.h +tensorflow/core/protobuf/cluster.pb.h tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/debug.pb.h tensorflow/core/protobuf/rewriter_config.pb.h diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt 
b/tensorflow/contrib/makefile/tf_pb_text_files.txt index f1da05e4c6e..c39257ffa91 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -1,6 +1,7 @@ tensorflow/core/util/saved_tensor_slice.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc +tensorflow/core/protobuf/cluster.pb_text.cc tensorflow/core/protobuf/config.pb_text.cc tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 2a78ea61016..5eadf5d55b6 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/meta_graph.proto +tensorflow/core/protobuf/cluster.proto tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/debug.proto tensorflow/core/protobuf/rewriter_config.proto diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index d57203c042d..727cdd9597a 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1338,6 +1338,87 @@ def streaming_sparse_precision_at_top_k(top_k_predictions, name=name_scope) +def sparse_recall_at_top_k(labels, + top_k_predictions, + class_id=None, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes recall@k of top-k predictions with respect to sparse labels. + + If `class_id` is specified, we calculate recall by considering only the + entries in the batch for which `class_id` is in the label, and computing + the fraction of them for which `class_id` is in the top-k `predictions`. 
+ If `class_id` is not specified, we'll calculate recall as how often on + average a class among the labels of a batch entry is in the top-k + `predictions`. + + `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>` + and `false_negative_at_<k>`, that are used to compute the recall_at_k + frequency. This frequency is ultimately returned as `recall_at_<k>`: an + idempotent operation that simply divides `true_positive_at_<k>` by total + (`true_positive_at_<k>` + `false_negative_at_<k>`). + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the + true positives and false negatives weighted by `weights`. Then `update_op` + increments `true_positive_at_<k>` and `false_negative_at_<k>` using these + values. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: `int64` `Tensor` or `SparseTensor` with shape + [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of + target classes for the associated prediction. Commonly, N=1 and `labels` + has shape [batch_size, num_labels]. [D1, ... DN] must match + `top_k_predictions`. Values should be in range [0, num_classes), where + num_classes is the last dimension of `predictions`. Values outside this + range always count towards `false_negative_at_<k>`. + top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where + N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. + The final dimension contains the indices of top-k labels. [D1, ... DN] + must match `labels`. + class_id: Integer class ID for which we want binary metrics. This should be + in range [0, num_classes), where num_classes is the last dimension of + `predictions`. If class_id is outside this range, the method returns NAN. + weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of + `labels`.
If the latter, it must be broadcastable to `labels` (i.e., all + dimensions must be either `1`, or the same as the corresponding `labels` + dimension). + metrics_collections: An optional list of collections that values should + be added to. + updates_collections: An optional list of collections that updates should + be added to. + name: Name of new update operation, and namespace for other dependent ops. + + Returns: + recall: Scalar `float64` `Tensor` with the value of `true_positives` divided + by the sum of `true_positives` and `false_negatives`. + update_op: `Operation` that increments `true_positives` and + `false_negatives` variables appropriately, and whose value matches + `recall`. + + Raises: + ValueError: If `weights` is not `None` and its shape doesn't match + `predictions`, or if either `metrics_collections` or `updates_collections` + are not a list or tuple. + """ + default_name = _at_k_name('recall', class_id=class_id) + with ops.name_scope(name, default_name, (top_k_predictions, labels, + weights)) as name_scope: + return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access + labels=labels, + predictions_idx=top_k_predictions, + class_id=class_id, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name_scope) + + def streaming_sparse_average_precision_at_k(predictions, labels, k, @@ -2288,6 +2369,7 @@ def _remove_squeezable_dimensions(predictions, labels, weights): __all__ = [ 'aggregate_metric_map', 'aggregate_metrics', + 'sparse_recall_at_top_k', 'streaming_accuracy', 'streaming_auc', 'streaming_false_negatives', @@ -2310,7 +2392,9 @@ __all__ = [ 'streaming_root_mean_squared_error', 'streaming_sensitivity_at_specificity', 'streaming_sparse_average_precision_at_k', + 'streaming_sparse_average_precision_at_top_k', 'streaming_sparse_precision_at_k', + 'streaming_sparse_precision_at_top_k', 'streaming_sparse_recall_at_k', 'streaming_specificity_at_sensitivity', 
'streaming_true_negatives', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index b960e1310ec..f42e974e238 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2958,8 +2958,38 @@ class StreamingSparseRecallTest(test.TestCase): self.assertEqual(expected, update.eval()) self.assertEqual(expected, metric.eval()) + def _test_sparse_recall_at_top_k(self, + labels, + top_k_predictions, + expected, + class_id=None, + weights=None): + with ops.Graph().as_default() as g, self.test_session(g): + if weights is not None: + weights = constant_op.constant(weights, dtypes_lib.float32) + metric, update = metric_ops.sparse_recall_at_top_k( + labels=labels, + top_k_predictions=constant_op.constant(top_k_predictions, + dtypes_lib.int32), + class_id=class_id, + weights=weights) + + # Fails without initialized vars. + self.assertRaises(errors_impl.OpError, metric.eval) + self.assertRaises(errors_impl.OpError, update.eval) + variables.variables_initializer(variables.local_variables()).run() + + # Run per-step op and assert expected values. 
+ if math.isnan(expected): + self.assertTrue(math.isnan(update.eval())) + self.assertTrue(math.isnan(metric.eval())) + else: + self.assertEqual(expected, update.eval()) + self.assertEqual(expected, metric.eval()) + def test_one_label_at_k1_nan(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + top_k_predictions = [[3], [3]] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 0, 1], [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) @@ -2970,9 +3000,12 @@ class StreamingSparseRecallTest(test.TestCase): for class_id in (-1, 0, 1, 4): self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=NAN, class_id=class_id) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=class_id) def test_one_label_at_k1_no_predictions(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + top_k_predictions = [[3], [3]] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 0, 1], [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) @@ -2981,9 +3014,12 @@ class StreamingSparseRecallTest(test.TestCase): # Class 2: 0 predictions. self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=0.0, class_id=2) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0, class_id=2) def test_one_label_at_k1(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + top_k_predictions = [[3], [3]] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 0, 1], [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) @@ -2992,13 +3028,18 @@ class StreamingSparseRecallTest(test.TestCase): # Class 3: 1 label, 2 predictions, 1 correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 1, class_id=3) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, class_id=3) # All classes: 2 labels, 2 predictions, 1 correct. 
self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 2) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2) def test_one_label_at_k1_weighted(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + top_k_predictions = [[3], [3]] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 0, 1], [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) @@ -3007,6 +3048,8 @@ class StreamingSparseRecallTest(test.TestCase): # Class 3: 1 label, 2 predictions, 1 correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3014,6 +3057,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=3, weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3021,6 +3070,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=3, weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3028,6 +3083,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=3, weights=(0.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3035,6 +3096,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=3, weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) 
self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3042,6 +3109,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=3, weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3049,6 +3122,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=3, weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3056,6 +3135,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=2.0 / 2, class_id=3, weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3063,6 +3148,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=3.0 / 3, class_id=3, weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3070,6 +3161,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=0.3 / 0.3, class_id=3, weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3077,32 +3174,70 @@ class StreamingSparseRecallTest(test.TestCase): expected=0.6 / 0.6, class_id=3, weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) # All classes: 2 labels, 2 predictions, 1 correct. 
self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=NAN, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=(0.0,)) self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) + 
self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) def test_three_labels_at_k5_nan(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] + top_k_predictions = [ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) @@ -3112,10 +3247,16 @@ class StreamingSparseRecallTest(test.TestCase): for class_id in (0, 3, 4, 6, 9, 10): self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=NAN, class_id=class_id) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=class_id) def test_three_labels_at_k5_no_predictions(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] + top_k_predictions = [ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) @@ -3124,10 +3265,16 @@ class StreamingSparseRecallTest(test.TestCase): # Class 8: 1 label, no predictions. 
self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=0.0 / 1, class_id=8) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, class_id=8) def test_three_labels_at_k5(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] + top_k_predictions = [ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) @@ -3136,23 +3283,35 @@ class StreamingSparseRecallTest(test.TestCase): # Class 2: 2 labels, both correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=2.0 / 2, class_id=2) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 2, class_id=2) # Class 5: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=1.0 / 1, class_id=5) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, class_id=5) # Class 7: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=0.0 / 1, class_id=7) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, class_id=7) # All classes: 6 labels, 3 correct. 
self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=3.0 / 6) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=3.0 / 6) def test_three_labels_at_k5_some_out_of_range(self): """Tests that labels outside the [0, n_classes) count in denominator.""" predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] + top_k_predictions = [ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ] sp_labels = sparse_tensor.SparseTensorValue( indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], [1, 3]], @@ -3167,6 +3326,11 @@ class StreamingSparseRecallTest(test.TestCase): k=5, expected=2.0 / 2, class_id=2) + self._test_sparse_recall_at_top_k( + sp_labels, + top_k_predictions, + expected=2.0 / 2, + class_id=2) # Class 5: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( @@ -3175,6 +3339,11 @@ class StreamingSparseRecallTest(test.TestCase): k=5, expected=1.0 / 1, class_id=5) + self._test_sparse_recall_at_top_k( + sp_labels, + top_k_predictions, + expected=1.0 / 1, + class_id=5) # Class 7: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( @@ -3183,16 +3352,30 @@ class StreamingSparseRecallTest(test.TestCase): k=5, expected=0.0 / 1, class_id=7) + self._test_sparse_recall_at_top_k( + sp_labels, + top_k_predictions, + expected=0.0 / 1, + class_id=7) # All classes: 8 labels, 3 correct. 
self._test_streaming_sparse_recall_at_k( predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8) + self._test_sparse_recall_at_top_k( + sp_labels, top_k_predictions, expected=3.0 / 8) def test_3d_nan(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] sparse_labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]]) @@ -3207,12 +3390,21 @@ class StreamingSparseRecallTest(test.TestCase): for class_id in (0, 3, 4, 6, 9, 10): self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=NAN, class_id=class_id) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=class_id) def test_3d_no_predictions(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] sparse_labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], @@ -3229,12 +3421,21 @@ class StreamingSparseRecallTest(test.TestCase): for class_id in (1, 8): self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=0.0, class_id=class_id) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0, class_id=class_id) def test_3d(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 
0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], @@ -3244,24 +3445,39 @@ class StreamingSparseRecallTest(test.TestCase): # Class 2: 4 labels, all correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=4.0 / 4, class_id=2) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=4.0 / 4, class_id=2) # Class 5: 2 labels, both correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=2.0 / 2, class_id=5) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 2, class_id=5) # Class 7: 2 labels, 1 incorrect. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=1.0 / 2, class_id=7) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, class_id=7) # All classes: 12 labels, 7 correct. 
self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=7.0 / 12) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=7.0 / 12) def test_3d_ignore_all(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], @@ -3276,6 +3492,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=class_id, weights=[[0], [0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=class_id, + weights=[[0], [0]]) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3283,16 +3505,33 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=class_id, weights=[[0, 0], [0, 0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=class_id, + weights=[[0, 0], [0, 0]]) self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=NAN, weights=[[0], [0]]) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=[[0], [0]]) self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]]) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=[[0, 0], [0, 0]]) def test_3d_ignore_some(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 
6, 2, 0], + ]] labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], @@ -3307,6 +3546,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=2.0 / 2.0, class_id=2, weights=[[1], [0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2.0, + class_id=2, + weights=[[1], [0]]) # Class 2: 2 labels, both correct. self._test_streaming_sparse_recall_at_k( @@ -3316,6 +3561,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=2.0 / 2.0, class_id=2, weights=[[0], [1]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2.0, + class_id=2, + weights=[[0], [1]]) # Class 7: 1 label, correct. self._test_streaming_sparse_recall_at_k( @@ -3325,6 +3576,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1.0, class_id=7, weights=[[0], [1]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1.0, + class_id=7, + weights=[[0], [1]]) # Class 7: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( @@ -3334,6 +3591,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=0.0 / 1.0, class_id=7, weights=[[1], [0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.0 / 1.0, + class_id=7, + weights=[[1], [0]]) # Class 7: 2 labels, 1 correct. self._test_streaming_sparse_recall_at_k( @@ -3343,6 +3606,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 2.0, class_id=7, weights=[[1, 0], [1, 0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 2.0, + class_id=7, + weights=[[1, 0], [1, 0]]) # Class 7: No labels. 
self._test_streaming_sparse_recall_at_k( @@ -3352,6 +3621,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=7, weights=[[0, 1], [0, 1]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=7, + weights=[[0, 1], [0, 1]]) def test_sparse_tensor_value(self): predictions = [[0.1, 0.3, 0.2, 0.4], diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index ab443eab6f6..9d67563eddd 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -304,6 +304,7 @@ filegroup( exclude = [ "**/METADATA", "**/OWNERS", + "tools/**", ], ), visibility = ["//tensorflow:__subpackages__"], @@ -351,3 +352,27 @@ tf_kernel_library( "//third_party/eigen3", ], ) + +py_binary( + name = "checkpoint_convert", + srcs = ["python/tools/checkpoint_convert.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], +) + +py_test( + name = "checkpoint_convert_test", + size = "small", + srcs = ["python/tools/checkpoint_convert_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":checkpoint_convert", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 15afac98237..f4589e3d9e1 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -74,7 +74,41 @@ class RNNCellTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 2]) - g, _ = core_rnn_cell_impl.BasicRNNCell(2)(x, m) + cell = core_rnn_cell_impl.BasicRNNCell(2) + g, _ = cell(x, m) + self.assertEqual( + ["root/basic_rnn_cell/%s:0" + % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + 
"root/basic_rnn_cell/%s:0" + % core_rnn_cell_impl._BIAS_VARIABLE_NAME], + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g], {x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]])}) + self.assertEqual(res[0].shape, (1, 2)) + + def testBasicRNNCellNotTrainable(self): + with self.test_session() as sess: + def not_trainable_getter(getter, *args, **kwargs): + kwargs["trainable"] = False + return getter(*args, **kwargs) + + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5), + custom_getter=not_trainable_getter): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + cell = core_rnn_cell_impl.BasicRNNCell(2) + g, _ = cell(x, m) + self.assertFalse(cell.trainable_variables) + self.assertEqual( + ["root/basic_rnn_cell/%s:0" + % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/basic_rnn_cell/%s:0" + % core_rnn_cell_impl._BIAS_VARIABLE_NAME], + [v.name for v in cell.non_trainable_variables]) sess.run([variables_lib.global_variables_initializer()]) res = sess.run( [g], {x.name: np.array([[1., 1.]]), @@ -114,10 +148,23 @@ class RNNCellTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 8]) - g, out_m = core_rnn_cell_impl.MultiRNNCell( + cell = core_rnn_cell_impl.MultiRNNCell( [core_rnn_cell_impl.BasicLSTMCell( 2, state_is_tuple=False) for _ in range(2)], - state_is_tuple=False)(x, m) + state_is_tuple=False) + g, out_m = cell(x, m) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" + % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" + % core_rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" + % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" + % 
core_rnn_cell_impl._BIAS_VARIABLE_NAME] + self.assertEqual( + expected_variable_names, [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) sess.run([variables_lib.global_variables_initializer()]) res = sess.run( [g, out_m], @@ -125,15 +172,7 @@ class RNNCellTest(test.TestCase): m.name: 0.1 * np.ones([1, 8])}) self.assertEqual(len(res), 2) variables = variables_lib.global_variables() - self.assertEqual(4, len(variables)) - self.assertEquals(variables[0].op.name, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/weights") - self.assertEquals(variables[1].op.name, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/biases") - self.assertEquals(variables[2].op.name, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/weights") - self.assertEquals(variables[3].op.name, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/biases") + self.assertEqual(expected_variable_names, [v.name for v in variables]) # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) expected_mem = np.array([[ diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index 884b51926eb..eba2c0d2acb 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -27,7 +27,6 @@ from __future__ import division from __future__ import print_function import collections -import contextlib import hashlib import math import numbers @@ -57,53 +56,6 @@ _BIAS_VARIABLE_NAME = "biases" _WEIGHTS_VARIABLE_NAME = "weights" -@contextlib.contextmanager -def _checked_scope(cell, scope, reuse=None, **kwargs): - if reuse is not None: - kwargs["reuse"] = reuse - with vs.variable_scope(scope, **kwargs) as checking_scope: - scope_name = checking_scope.name - if hasattr(cell, "_scope"): - cell_scope = cell._scope # pylint: disable=protected-access - if cell_scope.name != checking_scope.name: - 
raise ValueError( - "Attempt to reuse RNNCell %s with a different variable scope than " - "its first use. First use of cell was with scope '%s', this " - "attempt is with scope '%s'. Please create a new instance of the " - "cell if you would like it to use a different set of weights. " - "If before you were using: MultiRNNCell([%s(...)] * num_layers), " - "change to: MultiRNNCell([%s(...) for _ in range(num_layers)]). " - "If before you were using the same cell instance as both the " - "forward and reverse cell of a bidirectional RNN, simply create " - "two instances (one for forward, one for reverse). " - "In May 2017, we will start transitioning this cell's behavior " - "to use existing stored weights, if any, when it is called " - "with scope=None (which can lead to silent model degradation, so " - "this error will remain until then.)" - % (cell, cell_scope.name, scope_name, type(cell).__name__, - type(cell).__name__)) - else: - weights_found = False - try: - with vs.variable_scope(checking_scope, reuse=True): - vs.get_variable(_WEIGHTS_VARIABLE_NAME) - weights_found = True - except ValueError: - pass - if weights_found and reuse is None: - raise ValueError( - "Attempt to have a second RNNCell use the weights of a variable " - "scope that already has weights: '%s'; and the cell was not " - "constructed as %s(..., reuse=True). " - "To share the weights of an RNNCell, simply " - "reuse it in your second calculation, or create a new one with " - "the argument reuse=True." % (scope_name, type(cell).__name__)) - - # Everything is OK. Update the cell's scope and yield it. 
- cell._scope = checking_scope # pylint: disable=protected-access - yield checking_scope - - class BasicRNNCell(RNNCell): """The most basic RNN cell.""" diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index df36dd2bf9b..9672b8b85f0 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -39,9 +39,6 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -_checked_scope = core_rnn_cell_impl._checked_scope # pylint: disable=protected-access - - def _get_concat_variable(name, shape, dtype, num_shards): """Get a sharded variable concatenated into one tensor.""" sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py new file mode 100644 index 00000000000..1e29114b0cc --- /dev/null +++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py @@ -0,0 +1,231 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Convert checkpoints using RNNCells to new name convention. 
+ +Usage: + + python checkpoint_convert [--write_v1_checkpoint] \ + '/path/to/checkpoint' '/path/to/new_checkpoint' +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import collections +import re +import sys + +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import app +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as saver_lib + +_RNN_NAME_REPLACEMENTS = collections.OrderedDict([ + ############################################################################ + # contrib/rnn/python/ops/core_rnn_cell_impl.py + # BasicRNNCell + ('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'), + ('basic_rnn_cell/biases', 'basic_rnn_cell/bias'), + # GRUCell + ('gru_cell/weights', 'gru_cell/kernel'), + ('gru_cell/biases', 'gru_cell/bias'), + ('gru_cell/gates/weights', 'gru_cell/gates/kernel'), + ('gru_cell/gates/biases', 'gru_cell/gates/bias'), + ('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'), + ('gru_cell/candidate/biases', 'gru_cell/candidate/bias'), + # BasicLSTMCell + ('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'), + ('basic_lstm_cell/biases', 'basic_lstm_cell/bias'), + # LSTMCell + ('lstm_cell/weights', 'lstm_cell/kernel'), + ('lstm_cell/biases', 'lstm_cell/bias'), + ('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'), + ('lstm_cell/projection/biases', 'lstm_cell/projection/bias'), + # OutputProjectionWrapper + ('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'), + ('output_projection_wrapper/biases', 'output_projection_wrapper/bias'), + # InputProjectionWrapper + ('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'), + 
('input_projection_wrapper/biases', 'input_projection_wrapper/bias'), + ############################################################################ + # contrib/rnn/python/ops/lstm_ops.py + # LSTMBlockFusedCell ?? + ('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'), + ('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'), + ############################################################################ + # contrib/rnn/python/ops/rnn_cell.py + # LayerNormBasicLSTMCell + ('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'), + ('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'), + # UGRNNCell, not found in g3, but still need it? + ('ugrnn_cell/weights', 'ugrnn_cell/kernel'), + ('ugrnn_cell/biases', 'ugrnn_cell/bias'), + # NASCell + ('nas_rnn/weights', 'nas_rnn/kernel'), + ('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'), + # IntersectionRNNCell + ('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'), + ('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'), + ('intersection_rnn_cell/in_projection/weights', + 'intersection_rnn_cell/in_projection/kernel'), + ('intersection_rnn_cell/in_projection/biases', + 'intersection_rnn_cell/in_projection/bias'), + # PhasedLSTMCell + ('phased_lstm_cell/mask_gates/weights', + 'phased_lstm_cell/mask_gates/kernel'), + ('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'), + ('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'), + ('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'), + ('phased_lstm_cell/output_gate/weights', + 'phased_lstm_cell/output_gate/kernel'), + ('phased_lstm_cell/output_gate/biases', + 'phased_lstm_cell/output_gate/bias'), + # AttentionCellWrapper + ('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'), + ('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'), + ('attention_cell_wrapper/attn_output_projection/weights', + 
'attention_cell_wrapper/attn_output_projection/kernel'), + ('attention_cell_wrapper/attn_output_projection/biases', + 'attention_cell_wrapper/attn_output_projection/bias'), + ('attention_cell_wrapper/attention/weights', + 'attention_cell_wrapper/attention/kernel'), + ('attention_cell_wrapper/attention/biases', + 'attention_cell_wrapper/attention/bias'), +]) + +_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([ + ('LSTMCell/W_', 'lstm_cell/weights/part_'), + ('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'), + ('GRUCell/W_', 'gru_cell/weights/part_'), + ('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'), +]) + + +def _rnn_name_replacement(var_name): + for pattern in _RNN_NAME_REPLACEMENTS: + if pattern in var_name: + old_var_name = var_name + var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern]) + logging.info('Converted: %s --> %s' % (old_var_name, var_name)) + break + return var_name + + +def _rnn_name_replacement_sharded(var_name): + for pattern in _RNN_SHARDED_NAME_REPLACEMENTS: + if pattern in var_name: + old_var_name = var_name + var_name = var_name.replace(pattern, + _RNN_SHARDED_NAME_REPLACEMENTS[pattern]) + logging.info('Converted: %s --> %s' % (old_var_name, var_name)) + return var_name + + +def _split_sharded_vars(name_shape_map): + """Split shareded variables. + + Args: + name_shape_map: A dict from variable name to variable shape. + + Returns: + not_sharded: Names of the non-sharded variables. + sharded: Names of the sharded varibales. + """ + sharded = [] + not_sharded = [] + for name in name_shape_map: + if re.match(name, '_[0-9]+$'): + if re.sub('_[0-9]+$', '_1', name) in name_shape_map: + sharded.append(name) + else: + not_sharded.append(name) + else: + not_sharded.append(name) + return not_sharded, sharded + + +def convert_names(checkpoint_from_path, + checkpoint_to_path, + write_v1_checkpoint=False): + """Migrates the names of variables within a checkpoint. 
+ + Args: + checkpoint_from_path: Path to source checkpoint to be read in. + checkpoint_to_path: Path to checkpoint to be written out. + write_v1_checkpoint: Whether the output checkpoint will be in V1 format. + + Returns: + A dictionary that maps the new variable names to the Variable objects. + A dictionary that maps the old variable names to the new variable names. + """ + with ops.Graph().as_default(): + logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path) + reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path) + name_shape_map = reader.get_variable_to_shape_map() + not_sharded, sharded = _split_sharded_vars(name_shape_map) + new_variable_map = {} + conversion_map = {} + for var_name in not_sharded: + new_var_name = _rnn_name_replacement(var_name) + tensor = reader.get_tensor(var_name) + var = variables.Variable(tensor, name=var_name) + new_variable_map[new_var_name] = var + if new_var_name != var_name: + conversion_map[var_name] = new_var_name + for var_name in sharded: + new_var_name = _rnn_name_replacement_sharded(var_name) + var = variables.Variable(tensor, name=var_name) + new_variable_map[new_var_name] = var + if new_var_name != var_name: + conversion_map[var_name] = new_var_name + + write_version = (saver_pb2.SaverDef.V1 + if write_v1_checkpoint else saver_pb2.SaverDef.V2) + saver = saver_lib.Saver(new_variable_map, write_version=write_version) + + with session.Session() as sess: + sess.run(variables.global_variables_initializer()) + logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path) + saver.save(sess, checkpoint_to_path) + + logging.info('Summary:') + logging.info(' Converted %d variable name(s).' 
% len(new_variable_map)) + return new_variable_map, conversion_map + + +def main(_): + convert_names( + FLAGS.checkpoint_from_path, + FLAGS.checkpoint_to_path, + write_v1_checkpoint=FLAGS.write_v1_checkpoint) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument('checkpoint_from_path', type=str, + help='Path to source checkpoint to be read in.') + parser.add_argument('checkpoint_to_path', type=str, + help='Path to checkpoint to be written out.') + parser.add_argument('--write_v1_checkpoint', action='store_true', + help='Write v1 checkpoint') + FLAGS, unparsed = parser.parse_known_args() + + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py new file mode 100644 index 00000000000..e2fc2fa80ea --- /dev/null +++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py @@ -0,0 +1,108 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Unit tests for checkpoint converter.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import os +import tempfile + +from tensorflow.contrib.rnn.python.tools import checkpoint_convert +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class CheckpointConvertTest(test.TestCase): + + def setUp(self): + self._old_ckpt_path = tempfile.mktemp() + self._new_ckpt_path = tempfile.mktemp() + ops.reset_default_graph() + + def tearDown(self): + for file_name in glob.glob(self._old_ckpt_path + "*"): + os.remove(file_name) + for file_name in glob.glob(self._new_ckpt_path + "*"): + os.remove(file_name) + + def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self): + for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS: + new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name] + self.assertTrue(old_name) + self.assertTrue(new_name) + self.assertNotEqual(old_name, new_name) + for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS: + new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name] + self.assertTrue(old_name) + self.assertTrue(new_name) + self.assertNotEqual(old_name, new_name) + + def testConversionFromV2WithConvertedVariableNamesSucceeds(self): + variables.Variable(10.0, name="a") + for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS: + variables.Variable(20.0, name=old_name) + with session.Session() as sess: + saver = saver_lib.Saver() + sess.run(variables.global_variables_initializer()) + saver.save(sess, self._old_ckpt_path) + + new_var_map, conversion_map = checkpoint_convert.convert_names( + self._old_ckpt_path, self._new_ckpt_path) + 
self.assertTrue(glob.glob(self._new_ckpt_path + "*")) + self.assertItemsEqual( + ["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()), + new_var_map.keys()) + self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map) + + def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self): + variables.Variable(10.0, name="a") + with session.Session() as sess: + saver = saver_lib.Saver() + sess.run(variables.global_variables_initializer()) + saver.save(sess, self._old_ckpt_path) + + new_var_map, conversion_map = checkpoint_convert.convert_names( + self._old_ckpt_path, self._new_ckpt_path) + self.assertItemsEqual(["a"], new_var_map.keys()) + self.assertFalse(conversion_map) + + def testConversionToV1Succeeds(self): + variables.Variable(10.0, name="a") + variables.Variable( + 20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]) + + with session.Session() as sess: + saver = saver_lib.Saver() + sess.run(variables.global_variables_initializer()) + saver.save(sess, self._old_ckpt_path) + + new_var_map, conversion_map = checkpoint_convert.convert_names( + self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True) + self.assertItemsEqual( + ["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]], + new_var_map.keys()) + self.assertEqual( + {list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]: + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]}, + conversion_map) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 5ced8a4f089..b70d612f55b 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -261,7 +261,7 @@ from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops 
import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging @@ -657,7 +657,7 @@ def train(train_op, if local_init_op == _USE_DEFAULT: local_init_op = control_flow_ops.group( tf_variables.local_variables_initializer(), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) if sync_optimizer is not None and isinstance( sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 4cfdf844ce4..14deffc71bc 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [ "framework/versions.proto", "lib/core/error_codes.proto", "protobuf/config.proto", + "protobuf/cluster.proto", "protobuf/debug.proto", "protobuf/queue_runner.proto", "protobuf/rewriter_config.proto", @@ -506,6 +507,7 @@ tf_gen_op_libs( "image_ops", "io_ops", "linalg_ops", + "lookup_ops", "logging_ops", "math_ops", "nn_ops", @@ -582,6 +584,7 @@ cc_library( ":image_ops_op_lib", ":io_ops_op_lib", ":linalg_ops_op_lib", + ":lookup_ops_op_lib", ":logging_ops_op_lib", ":math_ops_op_lib", ":nn_ops_op_lib", @@ -708,6 +711,7 @@ cc_library( "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", + "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:math", "//tensorflow/core/kernels:multinomial_op", diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc index 78649afeb93..aa8a2d989bf 100644 --- a/tensorflow/core/common_runtime/device.cc +++ b/tensorflow/core/common_runtime/device.cc @@ -23,8 +23,7 @@ limitations under the License. 
namespace tensorflow { -Device::Device(Env* env, const DeviceAttributes& device_attributes, - Allocator* device_allocator) +Device::Device(Env* env, const DeviceAttributes& device_attributes) : DeviceBase(env), device_attributes_(device_attributes) { CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_)) << "Invalid device name: " << name(); diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 07c6bdd6831..c0e58f143e3 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -53,8 +53,7 @@ namespace tensorflow { class Device : public DeviceBase { public: - Device(Env* env, const DeviceAttributes& device_attributes, - Allocator* device_allocator); + Device(Env* env, const DeviceAttributes& device_attributes); ~Device() override; // Full name of this device (see top comment). diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index 7807656cb25..31f12d48337 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector& devices) for (Device* d : devices) { devices_.push_back(d); - // Register under both the full name and the local name. + // Register under the (1) full name, (2) canonical name, and (3) local name. 
string full_name = d->name(); device_map_[CopyToBackingStore(full_name)] = d; + DeviceNameUtils::ParsedName parsed_name = d->parsed_name(); + if (parsed_name.has_job && parsed_name.has_replica && + parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) { + string canonical_name = DeviceNameUtils::FullName( + parsed_name.job, parsed_name.replica, parsed_name.task, + parsed_name.type, parsed_name.id); + device_map_[CopyToBackingStore(canonical_name)] = d; + } string lname = DeviceNameUtils::LocalName(d->name()); device_map_[CopyToBackingStore(lname)] = d; device_type_counts_[d->device_type()]++; @@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector& devices) } DeviceMgr::~DeviceMgr() { - for (auto p : devices_) delete p; + // TODO(b/37437134): Remove destructor after converting to std::unique_ptr. + for (Device* p : devices_) delete p; } StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { @@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const { Status s; auto iter = device_map_.find(name); if (iter == device_map_.end()) { + std::vector device_names; + for (auto&& itr : device_map_) { + device_names.push_back(itr.first); + } + LOG(WARNING) << "Unknown device: " << name + << " all devices: " << str_util::Join(device_names, ", "); return errors::InvalidArgument(name, " unknown device."); } *device = iter->second; diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h index bb1ed726408..d16681ac59d 100644 --- a/tensorflow/core/common_runtime/device_mgr.h +++ b/tensorflow/core/common_runtime/device_mgr.h @@ -36,6 +36,7 @@ class DeviceMgr { public: // Takes ownership of each device in 'devices'. // TODO(zhifengc): Other initialization information. + // TODO(b/37437134): Use std::unique_ptr's to track ownership. 
explicit DeviceMgr(const std::vector& devices); ~DeviceMgr(); @@ -61,6 +62,7 @@ class DeviceMgr { int NumDeviceType(const string& type) const; private: + // TODO(b/37437134): Use std::unique_ptr's to track ownership. typedef gtl::InlinedVector DeviceVec; DeviceVec devices_; diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h index b0540dfa95b..4cd56e583c0 100644 --- a/tensorflow/core/common_runtime/device_set.h +++ b/tensorflow/core/common_runtime/device_set.h @@ -39,7 +39,10 @@ class DeviceSet { // Set the device designated as the "client". This device // must also be registered via AddDevice(). - void set_client_device(Device* device) { client_device_ = device; } + void set_client_device(Device* device) { + DCHECK(client_device_ == nullptr); + client_device_ = device; + } // Returns a pointer to the device designated as the "client". Device* client_device() const { return client_device_; } diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc index ff20ee94a7d..0507076c8c3 100644 --- a/tensorflow/core/common_runtime/device_set_test.cc +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -27,8 +27,7 @@ namespace { static Device* Dev(const char* type, const char* name) { class FakeDevice : public Device { public: - explicit FakeDevice(const DeviceAttributes& attr) - : Device(nullptr, attr, nullptr) {} + explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} Status Sync() override { return Status::OK(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 0e2343cfe3f..02f70d835d5 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const 
string& name, int gpu_id, const string& physical_device_desc, Allocator* gpu_allocator, Allocator* cpu_allocator, bool sync_every_op, int32 max_streams) - : LocalDevice(options, - Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit, - locality, physical_device_desc), - gpu_allocator), + : LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU, + memory_limit, locality, + physical_device_desc)), gpu_allocator_(gpu_allocator), cpu_allocator_(cpu_allocator), gpu_id_(gpu_id), diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc index 0a6342ed736..3f7c9f68dba 100644 --- a/tensorflow/core/common_runtime/local_device.cc +++ b/tensorflow/core/common_runtime/local_device.cc @@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo { }; LocalDevice::LocalDevice(const SessionOptions& options, - const DeviceAttributes& attributes, - Allocator* device_allocator) - : Device(options.env, attributes, device_allocator), - owned_tp_info_(nullptr) { + const DeviceAttributes& attributes) + : Device(options.env, attributes), owned_tp_info_(nullptr) { // If we're running on the CPU, log warnings if we're not compiled using the // best flags for performance. port::WarnAboutUnusedCPUFeatures(); diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h index d1c27c62481..84a4f66db4a 100644 --- a/tensorflow/core/common_runtime/local_device.h +++ b/tensorflow/core/common_runtime/local_device.h @@ -33,8 +33,8 @@ struct SessionOptions; // GPUDevice into more 'process-wide' abstractions. 
class LocalDevice : public Device { public: - LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes, - Allocator* device_allocator); + LocalDevice(const SessionOptions& options, + const DeviceAttributes& attributes); ~LocalDevice() override; private: diff --git a/tensorflow/core/common_runtime/renamed_device.cc b/tensorflow/core/common_runtime/renamed_device.cc new file mode 100644 index 00000000000..fa9713735ed --- /dev/null +++ b/tensorflow/core/common_runtime/renamed_device.cc @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/renamed_device.h" + +namespace tensorflow { + +// TODO(saeta): Convert to returning a std::unique_ptr? 
+/* static */ +Device* RenamedDevice::NewRenamedDevice(const string& new_base, + Device* underlying, + bool owns_underlying) { + DeviceNameUtils::ParsedName parsed_name; + CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name)); + DeviceNameUtils::ParsedName underlying_parsed_name = + underlying->parsed_name(); + CHECK(underlying_parsed_name.has_type); + CHECK(underlying_parsed_name.has_id); + parsed_name.type = underlying_parsed_name.type; + parsed_name.id = underlying_parsed_name.id; + string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica, + parsed_name.task, parsed_name.type, + parsed_name.id); + DeviceAttributes attributes(underlying->attributes()); + attributes.set_name(name); + return new RenamedDevice(underlying, attributes, owns_underlying); +} + +RenamedDevice::RenamedDevice(Device* underlying, + const DeviceAttributes& attributes, + bool owns_underlying) + : Device(underlying->env(), attributes), + underlying_(underlying), + owns_underlying_(owns_underlying) {} + +RenamedDevice::~RenamedDevice() { + if (owns_underlying_) { + delete underlying_; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h new file mode 100644 index 00000000000..0158e18cedc --- /dev/null +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -0,0 +1,119 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// Wraps a device with a new name, delegating work to the wrapped device. +// +// This class is used to wrap local devices when using clusterspec propagation +// where the name of a particular device may change in the context of a given +// session. +class RenamedDevice : public Device { + public: + static Device* NewRenamedDevice(const string& new_base, Device* underlying, + bool owns_underlying); + ~RenamedDevice() override; + + // Below are virtual methods defined on DeviceBase + bool RequiresRecordingAccessedTensors() const override { + return underlying_->RequiresRecordingAccessedTensors(); + } + + const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override { + return underlying_->tensorflow_cpu_worker_threads(); + } + + const GpuDeviceInfo* tensorflow_gpu_device_info() const override { + return underlying_->tensorflow_gpu_device_info(); + } + + Allocator* GetAllocator(AllocatorAttributes attr) override { + return underlying_->GetAllocator(attr); + } + + Allocator* GetStepAllocator(AllocatorAttributes attr, + ResourceMgr* step_resource_manager) override { + return underlying_->GetStepAllocator(attr, step_resource_manager); + } + + const Eigen::ThreadPoolDevice* eigen_cpu_device() override { + return underlying_->eigen_cpu_device(); + } + +#ifdef TENSORFLOW_USE_SYCL + const Eigen::SyclDevice* eigen_sycl_device() const override { + return underlying_->eigen_sycl_device(); + } +#endif + + PerOpGpuDevice* MakeGpuDevice() override { + return underlying_->MakeGpuDevice(); + } + + void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, + DeviceContext* dc, Allocator* allocator) 
override { + underlying_->ReinitializeGpuDevice(context, device, dc, allocator); + } + + Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override { + return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor); + } + + // Below are virtual methods defined on Device + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override { + underlying_->Compute(op_kernel, context); + } + + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override { + underlying_->ComputeAsync(op_kernel, context, std::move(done)); + } + + void ConsumeListOfAccessedTensors( + DeviceContext* context, const TensorReferenceVector& tensors) override { + underlying_->ConsumeListOfAccessedTensors(context, tensors); + } + + Status Sync() override { return underlying_->Sync(); } + + Status MaybeRewriteGraph(const FunctionDefLibrary& library, + std::unique_ptr* graph) override { + return underlying_->MaybeRewriteGraph(library, graph); + } + + Status FillContextMap(const Graph* graph, + DeviceContextMap* device_context_map) override { + return underlying_->FillContextMap(graph, device_context_map); + } + + private: + RenamedDevice(Device* underlying, const DeviceAttributes& attributes, + bool owns_underlying); + Device* const underlying_; + const bool owns_underlying_; +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc index bd84417b105..24f27af5f1a 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -66,7 +66,7 @@ class DummyOp : public OpKernel { class FakeDevice : public Device { private: explicit FakeDevice(const DeviceAttributes& device_attributes) - : Device(nullptr, device_attributes, nullptr) {} + 
: Device(nullptr, device_attributes) {} public: Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 60348e885f5..f5f8aab6946 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, const string& name, Bytes memory_limit, const DeviceLocality& locality, Allocator* allocator) - : LocalDevice(options, - Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit, - locality), - allocator), + : LocalDevice(options, Device::BuildDeviceAttributes( + name, DEVICE_CPU, memory_limit, locality)), allocator_(allocator) {} ThreadPoolDevice::~ThreadPoolDevice() {} diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 0f5eb0cb320..d2a828f39f2 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -77,7 +77,6 @@ cc_library( ], deps = [ ":graph_mgr", - ":rendezvous_mgr_interface", ":worker_cache", "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", @@ -92,9 +91,9 @@ cc_library( deps = [ ":graph_mgr", ":worker_session", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", ], ) @@ -237,6 +236,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:master_proto_cc", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index 5863727f19b..e68aea46ecd 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ 
b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -35,9 +35,8 @@ limitations under the License. namespace tensorflow { -BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env, - const string& worker_name) - : worker_env_(worker_env), worker_name_(worker_name) {} +BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env) + : worker_env_(worker_env) {} BaseRendezvousMgr::~BaseRendezvousMgr() { for (auto& p : table_) { @@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() { } } -Rendezvous* BaseRendezvousMgr::Find(int64 step_id) { +RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) { return FindOrCreate(step_id); } @@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) { mutex_lock l(mu_); Table::iterator iter = table_.find(step_id); if (iter == table_.end()) { - auto rr = Create(step_id, worker_env_, worker_name_); + auto rr = Create(step_id, worker_env_); iter = table_.insert({step_id, rr}).first; } iter->second->Ref(); @@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() { } } -BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, - const string& worker_name, - int64 step_id, +BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id, bool tolerate_dup_recv) : env_(env), - worker_name_(worker_name), step_id_(step_id), - local_(NewLocalRendezvous(tolerate_dup_recv)) {} + local_(NewLocalRendezvous(tolerate_dup_recv)), + session_(nullptr) {} BaseRemoteRendezvous::~BaseRemoteRendezvous() { CHECK(active_.empty()); @@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name, return device_name.starts_with(worker_name); } +Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { + CHECK_NE(session, nullptr) << "session must not be null!"; + std::vector deferred_calls; + { + mutex_lock l(mu_); + if (session_ != nullptr) { + if (session_->worker_name == session->worker_name) { + LOG(INFO) << "Skipping rendezvous re-initialization."; + return 
Status::OK(); + } + Status s = errors::Internal( + "Double init! Worker names would have changed from: ", + session_->worker_name, " -> ", session->worker_name); + LOG(WARNING) << s; + return s; + } + session_ = session; + std::swap(deferred_calls, deferred_calls_); + } + for (DeferredCall& call : deferred_calls) { + RecvLocalAsyncInternal(call.parsed, std::move(call.done)); + } + return Status::OK(); +} + +WorkerSession* BaseRemoteRendezvous::session() { + mutex_lock l(mu_); + return session_; +} + +bool BaseRemoteRendezvous::is_initialized() { + mutex_lock l(mu_); + return is_initialized_locked(); +} + Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) { @@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, { mutex_lock l(mu_); if (!status_.ok()) return status_; - } - if (!IsLocalDevice(worker_name_, parsed.src_device)) { - return errors::InvalidArgument("Invalid rendezvous key (src): ", - parsed.FullKey(), " @ ", worker_name_); + DCHECK(is_initialized_locked()); + if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { + return errors::InvalidArgument( + "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", + session_->worker_name); + } } // Buffers "val" and "device_context" in local_. return local_->Send(parsed, args, val, is_dead); @@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, bool is_src) { + // Cache session pointer to avoid repeatedly taking & releasing the lock + // (e.g. 
calling session()) + WorkerSession* sess = nullptr; { mutex_lock l(mu_); if (!status_.ok()) return status_; + if (!is_initialized_locked()) { + return errors::Internal("ValidateDevices called before initialization."); + } + sess = session_; } - if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) { + if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) { return errors::InvalidArgument("Invalid rendezvous key (src): ", - parsed.FullKey(), " @ ", worker_name_); + parsed.FullKey(), " @ ", sess->worker_name); } - if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) { + if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) { return errors::InvalidArgument("Invalid rendezvous key (dst): ", - parsed.FullKey(), " @ ", worker_name_); + parsed.FullKey(), " @ ", sess->worker_name); } return Status::OK(); } @@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); + CHECK(is_initialized()) << "RecvAsync called when uninitialized."; Status s = ValidateDevices(parsed, false /*!is_src*/); if (!s.ok()) { done(s, Args(), recv_args, Tensor(), false); @@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, DoneCallback done) { + { + mutex_lock l(mu_); + if (!is_initialized_locked()) { + // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a + // remote worker) before the RunStep (or PartialRunStep) RPC from the + // master arrives. RecvLocalAsync thus buffers the arguments until after + // the RemoteRendezvous is Initialize()'d, when it completes the + // rendezvous logic. At some point after Initialize() is called, a Tensor + // is produced locally that will then be sent in response to the incoming + // RPC. 
+ DeferredCall call(parsed, std::move(done)); + deferred_calls_.push_back(call); + return; + } + } + RecvLocalAsyncInternal(parsed, std::move(done)); +} + +void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed, + DoneCallback done) { Status s = ValidateDevices(parsed, true /* is_src */); if (!s.ok()) { done(s, Args(), Args(), Tensor(), false); @@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) { active_.erase(call); } +BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed, + DoneCallback done) + : parsed(parsed), done(std::move(done)) {} + } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h index 447a75913d6..b252f45fe96 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h @@ -59,15 +59,17 @@ class BaseRecvTensorCall; // RendezvousMgr must have keys generated by Rendezvous::CreateKey(). class BaseRendezvousMgr : public RendezvousMgrInterface { public: - explicit BaseRendezvousMgr(const WorkerEnv* worker_env, - const string& worker_name); + explicit BaseRendezvousMgr(const WorkerEnv* worker_env); ~BaseRendezvousMgr() override; // Returns Rendezvous supporting send and recv among workers in the // "step_id". The caller takes ownership of one reference on the // returned Rendezvous instance. - Rendezvous* Find(int64 step_id) override; + // + // Note: the caller must guarantee to eventually call Initialize on the + // returned RemoteRendezvous + RemoteRendezvous* Find(int64 step_id) override; // Finds the local rendezvous instance for the "step_id". Runs // "done" when the tensor for "key" is produced or an error occurs. 
@@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { protected: virtual BaseRemoteRendezvous* Create(int64 step_id, - const WorkerEnv* worker_env, - const string& worker_name) = 0; + const WorkerEnv* worker_env) = 0; private: // Maps step_id to rendezvous. @@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { // Not owned. const WorkerEnv* const worker_env_; - const string worker_name_; mutex mu_; Table table_ GUARDED_BY(mu_); @@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { // Buffering of Tensor values is delegated to a "local" Rendezvous // obtained from NewLocalRendezvous(). This class just adds // functionality to coordinate with remote workers. -class BaseRemoteRendezvous : public Rendezvous { +class BaseRemoteRendezvous : public RemoteRendezvous { public: - BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name, - int64 step_id, bool tolerate_dup_recv); + BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id, + bool tolerate_dup_recv); + + // Upgrades the BaseRemoteRendezvous to full initialization. + Status Initialize(WorkerSession* session) override; // Forwards to local_, where the Tensor "val" will be buffered and // any waiting callback stored. @@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous { // Removes "call" from active_ if "call" is in active_. void DeregisterCall(BaseRecvTensorCall* call); + WorkerSession* session(); + + bool is_initialized(); + ~BaseRemoteRendezvous() override; const WorkerEnv* const env_; // Not owned. - const string worker_name_; const int64 step_id_; private: @@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous { // Status given by StartAbort() if any. Status status_ GUARDED_BY(mu_); + WorkerSession* session_ GUARDED_BY(mu_); // Not owned. + + // Data structures to handle calls when partially initialized. 
+ struct DeferredCall { + const ParsedKey parsed; + DoneCallback done; + + DeferredCall(const ParsedKey& parsed, DoneCallback done); + }; + std::vector deferred_calls_ GUARDED_BY(mu_); // Active outstanding RecvTensor calls. gtl::FlatSet active_ GUARDED_BY(mu_); + bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return session_ != nullptr; + } + // If "is_src" is true, checks that the rendezvous key "parsed"'s // source is in this process. If "is_src" is false, checks that the // rendezvous key "parsed"'s destination is in this process. @@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous { const Rendezvous::Args& out_args, const Tensor& in, Tensor* out, StatusCallback done); + // Must be called only if fully initialized. + void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done); + TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); }; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index ce7ce372e85..5bde771e8de 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -46,10 +46,8 @@ limitations under the License. namespace tensorflow { -GraphMgr::GraphMgr(const WorkerEnv* worker_env, - RendezvousMgrInterface* rendezvous_mgr) - : worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) { - CHECK(rendezvous_mgr) << "Rendezvous mgr was null"; +GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr) + : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) { // The default value of sync_on_finish will be flipped soon and this // environment variable will be removed as well. 
Status status = @@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, }; popts.get_incarnation = [this](const string& name) -> int64 { Device* device = nullptr; - Status s = worker_env_->device_mgr->LookupDevice(name, &device); + Status s = device_mgr_->LookupDevice(name, &device); if (s.ok()) { return device->attributes().incarnation(); } else { @@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, ExecutionUnit* unit = &(item->units.back()); // Find the device. - Status s = - worker_env_->device_mgr->LookupDevice(device_name, &unit->device); + Status s = device_mgr_->LookupDevice(device_name, &unit->device); if (!s.ok()) { // Remove the empty unit from the item as the item destructor wants all // units to have valid devices. @@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, // Function library runtime. unit->lib = NewFunctionLibraryRuntime( - worker_env_->device_mgr, worker_env_->env, unit->device, + device_mgr_, worker_env_->env, unit->device, subgraph->versions().producer(), item->lib_def, graph_options.optimizer_options()); @@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, } Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = SendInputsToRendezvous(rendezvous, in); rendezvous->Unref(); return s; } Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = RecvOutputsFromRendezvous(rendezvous, out); rendezvous->Unref(); return s; @@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, 
StatusCallback done) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); RecvOutputsFromRendezvousAsync(rendezvous, out, [done, rendezvous](const Status s) { rendezvous->Unref(); @@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, } void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, - const ExecutorOpts& opts, + WorkerSession* session, + const ExecutorOpts& /*opts*/, StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, @@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, return; } - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); + Status s = rendezvous->Initialize(session); // Sends values specified by the caller. - Status s = SendInputsToRendezvous(rendezvous, in); + if (s.ok()) { + s = SendInputsToRendezvous(rendezvous, in); + } + if (!s.ok()) { done(s); item->Unref(); @@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, StatusCallback done) { const int num_units = item->units.size(); CHECK_GE(num_units, 1); - ScopedStepContainer* step_container = - new ScopedStepContainer(step_id, [this](const string& name) { - worker_env_->device_mgr->ClearContainers({name}); - }); + ScopedStepContainer* step_container = new ScopedStepContainer( + step_id, + [this](const string& name) { device_mgr_->ClearContainers({name}); }); // NOTE: Transfer one ref of rendezvous and item. 
ExecutorBarrier* barrier = new ExecutorBarrier(num_units, rendezvous, diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 349af6c54e5..50391f47e4d 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -37,6 +37,8 @@ namespace tensorflow { class ExecutorOpts; class StepStatsCollector; class RendezvousMgrInterface; +class DeviceMgr; +struct WorkerSession; // GraphMgr keeps track of a set of graphs that are registered with a // TensorFlow worker. Each registered graph is identified by a handle @@ -62,8 +64,7 @@ class RendezvousMgrInterface; // EXPECT_EQ(out["c"], Tensor({4, 6})); class GraphMgr { public: - explicit GraphMgr(const WorkerEnv* worker_env, - RendezvousMgrInterface* rendezvous_mgr); + explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr); ~GraphMgr(); // Registers a graph. Fills in "handle" @@ -78,8 +79,8 @@ class GraphMgr { typedef std::map NamedTensors; typedef std::function StatusCallback; void ExecuteAsync(const string& handle, const int64 step_id, - const ExecutorOpts& opts, StepStatsCollector* collector, - CostGraphDef* cost_graph, + WorkerSession* session, const ExecutorOpts& opts, + StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, const NamedTensors& in, StatusCallback done); @@ -131,7 +132,7 @@ class GraphMgr { }; const WorkerEnv* worker_env_; // Not owned. - RendezvousMgrInterface* rendezvous_mgr_; // Not owned. + DeviceMgr* device_mgr_; CostModelManager cost_model_manager_; diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index b4adee3bf6c..e860c99d953 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include #include +#include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/remote_device.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" @@ -48,12 +49,17 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/master.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +namespace { +const char* const kGrpcProtocol = "grpc://"; +} // namespace + Master::Master(MasterEnv* env, double session_gc_seconds) : env_(env), last_1000_steps_(1000), @@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req, CreateSessionResponse* resp, MyClosure done) { SchedClosure([this, req, resp, done]() { Status status; + WorkerCacheFactoryOptions worker_cache_factory_options; + string grpc_protocol("grpc"); + worker_cache_factory_options.protocol = &grpc_protocol; auto call_done = gtl::MakeCleanup([&status, &done] { done(status); }); status = ValidateExternalGraphDefSyntax(req->graph_def()); if (!status.ok()) return; - // Ping all the workers and build the list of devices that the - // session will use. + + // The following 4 variables are set differently, depending on whether this + // session uses a client-provided clusterspec or not. + WorkerCacheInterface* worker_cache = nullptr; + // Note: worker_cache_ptr will be null except if this session is using a + // client-supplied ClusterDef (ClusterSpec propagation). + std::unique_ptr worker_cache_ptr; + std::unique_ptr device_set; // TODO(saeta): Convert to std::make_unique when available. 
std::unique_ptr>> remote_devices( new std::vector>()); - status = DeviceFinder::GetRemoteDevices(req->config().device_filters(), - env_, env_->worker_cache, - remote_devices.get()); - if (!status.ok()) return; + + if (req->config().has_cluster_def()) { + worker_cache_factory_options.cluster_def = &req->config().cluster_def(); + + // Set the server_def's job_name and task_index fields. + string normalized_string; + string grpc_protocol(kGrpcProtocol); + if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) == + 0) { + normalized_string = + req->target().substr(grpc_protocol.length(), string::npos); + } else { + normalized_string = req->target(); + } + for (auto&& job : req->config().cluster_def().job()) { + for (auto&& task : job.tasks()) { + if (task.second == normalized_string) { + if (worker_cache_factory_options.job_name != nullptr) { + status = errors::InvalidArgument( + "Found multiple matching tasks that correspond to " + "the master. Master target: '", + req->target(), "'. ClusterDef: ", + req->config().cluster_def().ShortDebugString()); + LOG(ERROR) << status; + return; + } + if (env_->local_devices[0]->parsed_name().job == job.name() && + env_->local_devices[0]->parsed_name().task == task.first) { + // TODO(b/37868888): Remove this limitation when resolved + status = errors::InvalidArgument( + "The ClusterSpec names the job and task index to be the same " + "names that were provided when the server booted. This is " + "currently not allowed. Job: ", + job.name(), ", task index: ", task.first); + return; + } + worker_cache_factory_options.job_name = &job.name(); + worker_cache_factory_options.task_index = task.first; + } + } + } + + // Create the worker cache from the computed server_def. + status = env_->worker_cache_factory(worker_cache_factory_options, + &worker_cache); + if (!status.ok()) return; + worker_cache_ptr = std::unique_ptr(worker_cache); + // Ping all the workers and build the list of devices that the + // session will use.
+ status = + DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_, + worker_cache, remote_devices.get()); + if (!status.ok()) return; + device_set.reset(new DeviceSet); + for (auto&& d : *remote_devices) { + device_set->AddDevice(d.get()); + DeviceNameUtils::ParsedName name = d->parsed_name(); + if (name.job == *worker_cache_factory_options.job_name && + name.task == worker_cache_factory_options.task_index && + name.type == "CPU") { + device_set->set_client_device(d.get()); + } + } + } else { + worker_cache = env_->worker_cache; + // Ping all the workers and build the list of devices that the + // session will use. + status = + DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_, + worker_cache, remote_devices.get()); + if (!status.ok()) return; + device_set.reset(new DeviceSet); + for (auto&& d : *remote_devices) { + device_set->AddDevice(d.get()); + } + int num_local_devices = 0; + for (Device* d : env_->local_devices) { + device_set->AddDevice(d); + if (num_local_devices == 0) { + // Uses the first local device as the client device. 
+ device_set->set_client_device(d); + } + num_local_devices++; + } + } + + CHECK(device_set->client_device()); + SessionOptions options; options.config = req->config(); - MasterSession* session = - env_->master_session_factory(options, env_, std::move(remote_devices)); + + MasterSession* session = env_->master_session_factory( + options, env_, std::move(remote_devices), std::move(worker_cache_ptr), + std::move(device_set)); + GraphDef* gdef = const_cast(req)->mutable_graph_def(); - status = session->Create(gdef); + + status = session->Create(gdef, worker_cache_factory_options); if (!status.ok()) { session->Close().IgnoreError(); session->Unref(); diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h index a155bd384d8..bb548adda15 100644 --- a/tensorflow/core/distributed_runtime/master_env.h +++ b/tensorflow/core/distributed_runtime/master_env.h @@ -19,17 +19,41 @@ limitations under the License. #include #include -#include "tensorflow/core/distributed_runtime/master_session.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { class Device; +class DeviceSet; class Env; class MasterSession; class OpRegistryInterface; class WorkerCacheInterface; +// Options passed to the worker_cache_factory function. +struct WorkerCacheFactoryOptions { + const ClusterDef* cluster_def = nullptr; + const string* job_name = nullptr; + int task_index = 0; + const string* protocol = nullptr; + + WorkerCacheFactoryOptions() {} + + // Construct from a ServerDef proto. + // + // Note: server_def must outlive WorkerCacheFactoryOptions!
+ WorkerCacheFactoryOptions(const ServerDef& server_def) { + if (server_def.has_cluster() && !server_def.job_name().empty()) { + cluster_def = &server_def.cluster(); + job_name = &server_def.job_name(); + task_index = server_def.task_index(); + protocol = &server_def.protocol(); + } + } +}; + // The master environment class, which holds a bag of pointers to // per-master state. // @@ -57,8 +81,14 @@ struct MasterEnv { // `MasterEnv*` is retained by the caller. std::function>>)> + std::unique_ptr>>, + std::unique_ptr, + std::unique_ptr device_set)> master_session_factory; + + std::function + worker_cache_factory; }; } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index f7b422b70e3..50c5d90fc98 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -36,11 +36,13 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -162,7 +164,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { // Partitions the graph into subgraphs and registers them on // workers. Status RegisterPartitions(const PartitionOptions& popts, - const FunctionDefLibrary& func_def_lib); + const FunctionLibraryDefinition& flib_def); // Runs one step of all partitions. 
Status RunPartitions(const MasterEnv* env, int64 step_id, @@ -273,7 +275,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { }; Status MasterSession::ReffedClientGraph::RegisterPartitions( - const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib) { + const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) { { // Ensure register once. mu_.lock(); if (!init_started_) { @@ -292,7 +294,8 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions( graph_defs_for_publishing.push_back(&name_def.second); } stats_publisher_->PublishGraphProto(graph_defs_for_publishing); - s = DoRegisterPartitions(popts, func_def_lib, std::move(graph_defs)); + s = DoRegisterPartitions(popts, flib_def.ToProto(), + std::move(graph_defs)); } mu_.lock(); init_result_ = s; @@ -527,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( c->req->set_is_partial(is_partial_); c->req->set_is_last_partial_run(is_last_partial_run); } + c->req->set_session_handle(session_handle_); c->req->set_graph_handle(part.graph_handle); c->req->set_step_id(step_id); *c->req->mutable_exec_opts() = exec_opts; @@ -870,6 +874,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() { // The graph handle may be empty if we failed during partition registration. if (!part.graph_handle.empty()) { Call* c = new Call; + c->req.set_session_handle(session_handle_); c->req.set_graph_handle(part.graph_handle); // NOTE(mrry): We must capture `worker_cache_` since `this` // could be deleted before the callback is called. 
@@ -972,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) { MasterSession::MasterSession( const SessionOptions& opt, const MasterEnv* env, std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set, StatsPublisherFactory stats_publisher_factory) : session_opts_(opt), env_(env), handle_(strings::FpToString(random::New64())), remote_devs_(std::move(remote_devs)), + worker_cache_(std::move(worker_cache)), + devices_(std::move(device_set)), stats_publisher_factory_(std::move(stats_publisher_factory)), graph_version_(0), run_graphs_(5), partial_run_graphs_(5) { UpdateLastAccessTime(); + CHECK(devices_) << "device_set was null!"; VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size() << " #remote " << remote_devs_->size(); - for (auto&& d : *remote_devs_) { - devices_.AddDevice(d.get()); - } - int num_local_devices = 0; - for (Device* d : env->local_devices) { - devices_.AddDevice(d); - if (num_local_devices == 0) { - // Uses the first local device as the client device. - devices_.set_client_device(d); - } - num_local_devices++; - } + LOG(INFO) << "Start master session " << handle_ << " with config: " << std::endl << session_opts_.config.DebugString(); @@ -1011,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() { last_access_time_usec_.store(Env::Default()->NowMicros()); } -Status MasterSession::Create(GraphDef* graph_def) { +Status MasterSession::Create(GraphDef* graph_def, + const WorkerCacheFactoryOptions& options) { if (session_opts_.config.graph_options().place_pruned_graph()) { // TODO(b/29900832): Fix this or remove the option. 
LOG(WARNING) << "Distributed session does not support the " @@ -1019,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) { session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false); } - SimpleGraphExecutionStateOptions options; - options.device_set = &devices_; - options.session_options = &session_opts_; + SimpleGraphExecutionStateOptions execution_options; + execution_options.device_set = devices_.get(); + execution_options.session_options = &session_opts_; { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( - graph_def, options, &execution_state_)); + graph_def, execution_options, &execution_state_)); + } + if (options.cluster_def != nullptr) { + return CreateWorkerSessions(options); } return Status::OK(); } +Status MasterSession::CreateWorkerSessions( + const WorkerCacheFactoryOptions& options) { + CHECK(worker_cache_) << "CreateWorkerSessions should be called only with " + << "dynamic cluster membership."; + std::vector worker_names; + worker_cache_->ListWorkers(&worker_names); + + struct WorkerGroup { + // The worker name. (Not owned.) + const string* name; + + // The worker referenced by name. (Not owned.) + WorkerInterface* worker = nullptr; + + // Request and responses used for a given worker. + CreateWorkerSessionRequest request; + CreateWorkerSessionResponse response; + Status status = Status::OK(); + }; + BlockingCounter done(worker_names.size()); + std::vector workers(worker_names.size()); + + // Release the workers. + auto cleanup = gtl::MakeCleanup([this, &workers] { + for (auto&& worker_group : workers) { + if (worker_group.worker != nullptr) { + worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker); + } + } + }); + + Status status = Status::OK(); + // Create all the workers & kick off the computations. 
+ for (size_t i = 0; i < worker_names.size(); ++i) { + workers[i].name = &worker_names[i]; + workers[i].worker = worker_cache_->CreateWorker(worker_names[i]); + workers[i].request.set_session_handle(handle_); + *workers[i].request.mutable_server_def()->mutable_cluster() = + *options.cluster_def; + workers[i].request.mutable_server_def()->set_protocol(*options.protocol); + + DeviceNameUtils::ParsedName name; + if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) { + status = errors::Internal("Could not parse name ", worker_names[i]); + LOG(WARNING) << status; + return status; + } + if (!name.has_job || !name.has_task) { + status = errors::Internal("Incomplete worker name ", worker_names[i]); + LOG(WARNING) << status; + return status; + } + + workers[i].request.mutable_server_def()->set_job_name(name.job); + workers[i].request.mutable_server_def()->set_task_index(name.task); + } + + for (size_t i = 0; i < worker_names.size(); ++i) { + auto cb = [i, &workers, &done](const Status& s) { + workers[i].status = s; + done.DecrementCount(); + }; + workers[i].worker->CreateWorkerSessionAsync(&workers[i].request, + &workers[i].response, cb); + } + + done.Wait(); + for (size_t i = 0; i < workers.size(); ++i) { + status.Update(workers[i].status); + } + return status; +} + Status MasterSession::Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp) { UpdateLastAccessTime(); @@ -1059,6 +1135,13 @@ Status MasterSession::Extend(const ExtendSessionRequest* req, return Status::OK(); } +WorkerCacheInterface* MasterSession::get_worker_cache() const { + if (worker_cache_) { + return worker_cache_.get(); + } + return env_->worker_cache; +} + Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** rcg, bool is_partial) { const uint64 hash = HashBuildGraphOptions(opts); @@ -1082,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, << "\n"; std::unique_ptr client_graph; 
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); + WorkerCacheInterface* worker_cache = get_worker_cache(); auto entry = new ReffedClientGraph( handle_, opts, std::move(client_graph), session_opts_, stats_publisher_factory_, execution_state_.get(), is_partial, - env_->worker_cache); - + worker_cache); iter = m->insert({hash, entry}).first; VLOG(1) << "Preparing to execute new graph"; } @@ -1161,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, return errors::FailedPrecondition("Session is closed."); } ++num_running_; + // Note: all code paths must eventually call MarkRunCompletion() + // in order to appropriate decrement the num_running_ counter. } Status status; if (!req.partial_run_handle().empty()) { @@ -1168,16 +1253,18 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, } else { status = DoRunWithLocalExecution(opts, req, resp); } - { - mutex_lock l(mu_); - --num_running_; - if (num_running_ == 0) { - num_running_is_zero_.notify_all(); - } - } return status; } +// Decrements num_running_ and broadcasts if num_running_ is zero. +void MasterSession::MarkRunCompletion() { + mutex_lock l(mu_); + --num_running_; + if (num_running_ == 0) { + num_running_is_zero_.notify_all(); + } +} + Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { // Registers subgraphs if haven't done so. 
PartitionOptions popts; @@ -1187,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { return strings::StrCat(prefix, "_S", next_node_id_++); }; popts.get_incarnation = [this](const string& name) -> int64 { - Device* d = devices_.FindDeviceByName(name); + Device* d = devices_->FindDeviceByName(name); if (d == nullptr) { return PartitionOptions::kIllegalIncarnation; } else { @@ -1214,7 +1301,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { } TF_RETURN_IF_ERROR( - rcg->RegisterPartitions(popts, rcg->client_graph()->flib_def->ToProto())); + rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def)); return Status::OK(); } @@ -1222,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp) { + auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); const string& prun_handle = req.partial_run_handle(); RunState* run_state = nullptr; { @@ -1320,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts, rcg->Ref(); rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(), req.options(), resp->mutable_metadata()); + cleanup.release(); // MarkRunCompletion called in done closure. 
rcg->CleanupPartitionsAsync( run_state->step_id, [this, rcg, prun_handle](const Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; } rcg->Unref(); + MarkRunCompletion(); }); mutex_lock l(mu_); partial_runs_.erase(prun_handle); @@ -1367,10 +1457,10 @@ Status MasterSession::CreateDebuggerState( Status MasterSession::DoRunWithLocalExecution( CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp) { - VLOG(2) << "DoRunWithLocalExecution " - << "req: " << req.DebugString(); + VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString(); PerStepState pss; pss.start_micros = Env::Default()->NowMicros(); + auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); // Prepare. BuildGraphOptions bgopts; @@ -1437,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution( } } rcg->Ref(); - rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) { + cleanup.release(); // MarkRunCompletion called in done closure. + rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; } rcg->Unref(); + MarkRunCompletion(); }); return s; } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index d47125be992..3acc5bc5f0a 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/message_wrappers.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/master.pb.h" @@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted { MasterSession( const SessionOptions& options, const MasterEnv* env, std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set, StatsPublisherFactory stats_publisher_factory); // Initialize the MasterSession for "def". Must be called before Extend(), // Run(), or Close(). // // After this method returns, `def` will no longer be valid. - Status Create(GraphDef* def); + Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options); // Returns the session handle. const string& handle() const { return handle_; } @@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted { std::unique_ptr>> remote_devs_; + // The optional session-specific worker cluster. + // TODO(saeta): Convert to std::optional when available. + std::unique_ptr worker_cache_; + // Retrieves either worker_cache_ or the env_->worker_cache as appropriate. + WorkerCacheInterface* get_worker_cache() const; + // The device set used by this session. - DeviceSet devices_; + std::unique_ptr devices_; StatsPublisherFactory stats_publisher_factory_; @@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted { // Private dtor. The client must call Close(). virtual ~MasterSession(); + // Creates sessions on all workers. + // + // If this session is operating using the new ClusterSpec propagation behavior + // call this method in order to propagate the cluster membership to all + // workers. 
+ Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); + Status StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** graph, bool is_partial); void ClearRunsTable(std::vector* to_unref, @@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted { MutableRunStepResponseWrapper* resp); Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp); + void MarkRunCompletion(); void UpdateLastAccessTime(); Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index 7b58feb93cc..b077975ea50 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const { const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; } +const string& InMemoryRunGraphRequest::session_handle() const { + return session_handle_; +} + +void InMemoryRunGraphRequest::set_session_handle(const string& handle) { + session_handle_ = handle; +} + const string& InMemoryRunGraphRequest::graph_handle() const { return graph_handle_; } @@ -320,6 +328,7 @@ void InMemoryRunGraphRequest::set_is_last_partial_run( const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { if (!proto_version_) { proto_version_.reset(new RunGraphRequest); + proto_version_->set_session_handle(session_handle()); proto_version_->set_graph_handle(graph_handle()); proto_version_->set_step_id(step_id()); *proto_version_->mutable_exec_opts() = exec_opts(); @@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { return *proto_version_; } +const string& MutableProtoRunGraphRequest::session_handle() const { + return request_.session_handle(); +} + +void MutableProtoRunGraphRequest::set_session_handle(const string& handle) { + 
request_.set_session_handle(handle); +} + const string& MutableProtoRunGraphRequest::graph_handle() const { return request_.graph_handle(); } @@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const { ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request) : request_(request) {} +const string& ProtoRunGraphRequest::session_handle() const { + return request_->session_handle(); +} + const string& ProtoRunGraphRequest::graph_handle() const { return request_->graph_handle(); } diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 02516eabb4a..795a6add0e7 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -223,6 +223,10 @@ class RunGraphRequestWrapper { public: virtual ~RunGraphRequestWrapper() {} + // The session handle used to register the graph. If empty, a single global + // namespace is used. + virtual const string& session_handle() const = 0; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. virtual const string& graph_handle() const = 0; @@ -262,6 +266,7 @@ class RunGraphRequestWrapper { // See `RunGraphRequestWrapper` above for a description of the fields. class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { public: + virtual void set_session_handle(const string& handle) = 0; virtual void set_graph_handle(const string& handle) = 0; virtual void set_step_id(int64 step_id) = 0; virtual ExecutorOpts* mutable_exec_opts() = 0; @@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { public: // RunGraphRequestWrapper methods. 
+ const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; @@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { const RunGraphRequest& ToProto() const override; // MutableRunGraphRequestWrapper methods. + void set_session_handle(const string& handle) override; void set_graph_handle(const string& handle) override; void set_step_id(int64 step_id) override; ExecutorOpts* mutable_exec_opts() override; @@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { void set_is_last_partial_run(bool is_last_partial_run) override; private: + string session_handle_; string graph_handle_; int64 step_id_; ExecutorOpts exec_opts_; @@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { public: // RunGraphRequestWrapper methods. + const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; @@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { const RunGraphRequest& ToProto() const override; // MutableRunGraphRequestWrapper methods. + void set_session_handle(const string& handle) override; void set_graph_handle(const string& handle) override; void set_step_id(int64 step_id) override; ExecutorOpts* mutable_exec_opts() override; @@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper { ProtoRunGraphRequest(const RunGraphRequest* request); // RunGraphRequestWrapper methods. 
+ const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index 9632e9c4398..91c1fb99fef 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/remote_device.h" #include + #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) { class RemoteDevice : public Device { public: RemoteDevice(Env* env, const DeviceAttributes& da) - : Device(env, da, nullptr), - local_dev_name_(GetLocalDeviceName(da.name())) {} + : Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {} Status Sync() override { return Status::OK(); } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } @@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, GetStatusResponse resp; }; Call* call = new Call; - auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) { + auto cb = [env, worker_cache, worker_name, done, wi, + call](const Status& status) { + Status s = status; std::vector remote_devices; + auto cleanup = gtl::MakeCleanup( + [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] { + worker_cache->ReleaseWorker(worker_name, wi); + 
done(s, &remote_devices); + delete call; + }); if (s.ok()) { + DeviceNameUtils::ParsedName worker_name_parsed; + if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) || + !worker_name_parsed.has_job || !worker_name_parsed.has_replica || + !worker_name_parsed.has_task) { + s = errors::InvalidArgument("Could not parse worker name: ", + worker_name); + LOG(WARNING) << s; + return; + } remote_devices.reserve(call->resp.device_attributes_size()); for (const DeviceAttributes& da : call->resp.device_attributes()) { - auto d = new RemoteDevice(env, da); - remote_devices.push_back(d); + DeviceNameUtils::ParsedName device_name_parsed; + CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed)) + << "Device attribute name '" << da.name() << "' could not be " + << "parsed. Device Attribute: " << da.DebugString(); + // Preserve the exact name, if possible. + // TODO(b/37868888): Simplify when legacy device name formats removed. + if (device_name_parsed.job == worker_name_parsed.job && + device_name_parsed.replica == worker_name_parsed.replica && + device_name_parsed.task == worker_name_parsed.task) { + auto d = new RemoteDevice(env, da); + remote_devices.push_back(d); + } else { + DeviceAttributes da_rewritten = da; + da_rewritten.set_name(DeviceNameUtils::FullName( + worker_name_parsed.job, worker_name_parsed.replica, + worker_name_parsed.task, device_name_parsed.type, + device_name_parsed.id)); + auto d = new RemoteDevice(env, da_rewritten); + remote_devices.push_back(d); + } } } - worker_cache->ReleaseWorker(worker_name, wi); - done(s, &remote_devices); - delete call; }; wi->GetStatusAsync(&call->req, &call->resp, cb); } diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h index 04c1fc248ef..43267d4362f 100644 --- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h +++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -25,6 
+25,23 @@ limitations under the License. namespace tensorflow { +struct WorkerSession; + +// RemoteRendezvous follow a 2-part initialization. First the objects are +// constructed. Eventually, they will be initialized. Clients of the +// RendezvousMgrInterface must guarantee to call Initialize on the returned +// RemoteRendezvous eventually. +// +// Partially initialized RemoteRendezvous must respect the Rendezvous interface +// (i.e. Send() must never block), however implementations are not expected to +// actually perform the underlying operations until after the RemoteRendezvous +// has been Initialize'd. +class RemoteRendezvous : public Rendezvous { + public: + // Fully construct the RemoteRendezvous. + virtual Status Initialize(WorkerSession* session) = 0; +}; + // RendezvousMgr keeps track of a set of local rendezvous instances. // All tensors sent by this worker are buffered in a RendezvousMgr // until the tensor is received. Each global unique "step_id" @@ -51,7 +68,10 @@ class RendezvousMgrInterface { // Returns Rendezvous supporting send and recv among workers in the // "step_id". The caller takes ownership of one reference on the // returned Rendezvous instance. - virtual Rendezvous* Find(int64 step_id) = 0; + // + // Note: the caller must guarantee to eventually call Initialize on the + // returned RemoteRendezvous + virtual RemoteRendezvous* Find(int64 step_id) = 0; // Finds the local rendezvous instance for the "step_id". Runs // "done" when the tensor for "key" is produced or an error occurs. 
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 7160962b168..3867dd1f4d0 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption { }; // static utility function -RendezvousMgrInterface* NewRpcRendezvousMgr( - const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache) { - return new RpcRendezvousMgr(env, worker_name, worker_cache); +RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) { + return new RpcRendezvousMgr(env); } } // namespace @@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() { // TODO(mrry): Refactor the *Env classes so that it is less fiddly // to destroy them. + // Shut down all outstanding rendezvous. + delete worker_env_.rendezvous_mgr; + // We must delete graph_mgr before device_mgr, due to shared // ownership of OpKernels in the executors. (The graph_mgr will // free all stateless OpKernels, and pass over borrowed stateful @@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() { // OpSegments.) if (worker_env_.session_mgr != nullptr) { delete worker_env_.session_mgr; // Deletes graph_mgr's. + } else { + // Note: session_mgr's legacy_session_ deletes device_mgr now. 
+ delete worker_env_.device_mgr; } - delete worker_env_.device_mgr; // Do not delete (as these are not owned by the server): // - master_env_.env @@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() { // - worker_env_.compute_pool } -Status GrpcServer::Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendevous_mgr_func) { +Status GrpcServer::Init( + ServiceInitFunction service_func, + const RendezvousMgrCreationFunction& rendezvous_mgr_func) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func, "/task:", server_def_.task_index()); TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix, &master_env_.local_devices)); - worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices); + worker_env_.local_devices = master_env_.local_devices; + worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices); + worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr + ? new RpcRendezvousMgr(&worker_env_) + : rendezvous_mgr_func(&worker_env_); string unused; string default_worker_name; if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(), @@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func, } WorkerCacheInterface* worker_cache; - TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache)); + WorkerCacheFactoryOptions worker_cache_factory_options(server_def_); + TF_RETURN_IF_ERROR( + WorkerCacheFactory(worker_cache_factory_options, &worker_cache)); CHECK_NE(nullptr, worker_cache); // Set up worker environment. - std::unique_ptr rendezvous_mgr( - rendevous_mgr_func == nullptr ? 
- new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) : - rendevous_mgr_func(&worker_env_, name_prefix, worker_cache)); worker_env_.session_mgr = new SessionMgr( &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), std::unique_ptr(worker_cache), - std::move(rendezvous_mgr), [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { - return WorkerCacheFactory(server_def, worker_cache); + WorkerCacheFactoryOptions options(server_def); + return WorkerCacheFactory(options, worker_cache); }); worker_env_.compute_pool = ComputePool(sess_opts); @@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func, master_env_.master_session_factory = [config]( SessionOptions options, const MasterEnv* env, - std::unique_ptr>> remote_devs) { + std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set) { options.config.MergeFrom(config); return new MasterSession(options, env, std::move(remote_devs), + std::move(worker_cache), std::move(device_set), CreateNoOpStatsPublisher); }; + master_env_.worker_cache_factory = + [this](const WorkerCacheFactoryOptions& options, + WorkerCacheInterface** worker_cache) { + return WorkerCacheFactory(options, worker_cache); + }; // Provide direct access to the master from in-process clients. 
LocalMaster::Register(target(), master_impl_.get(), @@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func, return Status::OK(); } -Status GrpcServer::Init() { - return Init(nullptr, nullptr); -} +Status GrpcServer::Init() { return Init(nullptr, nullptr); } -Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, +Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec) { - for (const auto& job : server_def.cluster().job()) { + for (const auto& job : options.cluster_def->job()) { std::map host_ports; for (const auto& task : job.tasks()) { string& host_port = host_ports[task.first]; @@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, task.first, "\": ", host_port, " and ", task.second); } - if (job.name() == server_def.job_name() && - task.first == server_def.task_index()) { + if (job.name() == *options.job_name && task.first == options.task_index) { host_port = strings::StrCat("localhost:", bound_port_); } else { host_port = task.second; @@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, return Status::OK(); } -Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def, +Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, WorkerCacheInterface** worker_cache) { - string name_prefix = - strings::StrCat("/job:", server_def.job_name(), "/replica:0", - "/task:", server_def.task_index()); + if (options.job_name == nullptr || options.job_name->empty()) { + Status s = errors::InvalidArgument( + "The master (current machine) is not included in the provided " + "cluster_def. 
", + options.cluster_def->DebugString()); + LOG(WARNING) << s; + return s; + } GrpcChannelSpec channel_spec; - TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); + TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec)); + + std::unique_ptr channel_cache( + NewGrpcChannelCache(channel_spec, GetChannelCreationFunction())); + + string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0", + "/task:", options.task_index); - std::unique_ptr channel_cache(NewGrpcChannelCache( - channel_spec, GetChannelCreationFunction(server_def))); const string host_port = channel_cache->TranslateTask(name_prefix); int requested_port; @@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials( return ::grpc::InsecureServerCredentials(); } -ChannelCreationFunction GrpcServer::GetChannelCreationFunction( - const ServerDef& server_def) const { +ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const { // We can do this because SparseGrpcChannelCache is robust to nullptr being // returned by the channel creation function return ConvertToChannelCreationFunction(NewHostPortGrpcChannel); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 3b66291a9ab..7b54bb84c88 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -37,9 +37,7 @@ class GrpcWorker; class Master; // function that creates a RendezvousMgr. -typedef std::function +typedef std::function RendezvousMgrCreationFunction; // function that registers a service to the server. 
The service needs to @@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface { protected: Status Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendezvous_mgr_func); + const RendezvousMgrCreationFunction& rendezvous_mgr_func); Status Init(); @@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface { virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( const ServerDef& server_def) const; - virtual ChannelCreationFunction GetChannelCreationFunction( - const ServerDef& server_def) const; + virtual ChannelCreationFunction GetChannelCreationFunction() const; virtual std::unique_ptr CreateMaster(MasterEnv* master_env); // Creates a WorkerCacheInterface for a session. - Status WorkerCacheFactory(const ServerDef& server_def, + Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, WorkerCacheInterface** worker_cache); - // Parses a ServerDef into a GrpcChannelSpec. - Status ParseChannelSpec(const ServerDef& server_def, + // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. + Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec); // Returns the port to which this server is bound. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 1aacef8a26a..38d59d5bb59 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix); /* static */ Status GrpcSession::Create(const SessionOptions& options, std::unique_ptr* out_session) { - std::unique_ptr ret(new GrpcSession(options)); + std::unique_ptr session(new GrpcSession(options)); std::unique_ptr master; // For testing, we enable the client to disable the use of the local // master registry, so that the RPC stack is exercised. 
@@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options, options.target.substr(kSchemePrefixLength), &master_channel)); master.reset(NewGrpcMaster(master_channel)); } - ret->SetRemoteMaster(std::move(master)); - *out_session = std::move(ret); + session->SetRemoteMaster(std::move(master)); + *out_session = std::move(session); return Status::OK(); } @@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options, CreateSessionRequest req; *req.mutable_config() = options_.config; *req.mutable_graph_def() = graph; + req.set_target(options_.target); ReEncodeConsts(req.mutable_graph_def()); CreateSessionResponse resp; Status s = master_->CreateSession(call_options, &req, &resp); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index c11266587d8..873ef8588f4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface { // completes, and we may decide to bound some of the request // types. 
ENQUEUE_REQUEST(GetStatus, false); + ENQUEUE_REQUEST(CreateWorkerSession, false); ENQUEUE_REQUEST(CleanupAll, false); ENQUEUE_REQUEST(RegisterGraph, false); ENQUEUE_REQUEST(DeregisterGraph, false); @@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface { ENQUEUE_REQUEST(GetStatus, false); } + void CreateWorkerSessionHandler( + WorkerCall* + call) { + Schedule([this, call]() { + Status s = worker_->CreateWorkerSession(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(CreateWorkerSession, false); + } + void CleanupAllHandler( WorkerCall* call) { Schedule([this, call]() { @@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts, ::grpc::ByteBuffer* response, StatusCallback done) { const int64 step_id = request->step_id(); - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); const string& key = request->rendezvous_key(); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); Rendezvous::ParsedKey parsed; @@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts, // of execution of the callback lambda body below, an RPC // cancellation should abort the rendezvous. 
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); }); - session->rendezvous_mgr->RecvLocalAsync( + env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, [opts, response, done, src_dev](const Status& status, const Rendezvous::Args& send_args, diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 7518a289fdb..8265100061e 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -38,9 +38,8 @@ namespace { class RpcRemoteRendezvous : public BaseRemoteRendezvous { public: - RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* cache, int64 step_id) - : BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {} + RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id) + : BaseRemoteRendezvous(env, step_id, false) {} protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, @@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { private: ~RpcRemoteRendezvous() override {} - WorkerCacheInterface* const cache_; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous); }; @@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() { return call_freelist; } -// A private cache that wraps worker_cache and allows reuse of -// WorkerInterface objects. 
-class WorkerFreeListCache : public WorkerCacheInterface { - public: - explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {} - - ~WorkerFreeListCache() { - for (auto p : workers_) { - wrapped_->ReleaseWorker(p.first, p.second.worker); - } - } - - void ListWorkers(std::vector* workers) const override { - wrapped_->ListWorkers(workers); - } - - WorkerInterface* CreateWorker(const string& target) override { - mutex_lock l(mu_); - auto p = workers_.find(target); - if (p != workers_.end()) { - return p->second.worker; - } - WorkerState state; - state.worker = wrapped_->CreateWorker(target); - if (state.worker != nullptr) { - workers_.insert(std::make_pair(target, state)); - } - return state.worker; - } - - void ReleaseWorker(const string& target, WorkerInterface* worker) override { - // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. - } - - bool GetDeviceLocalityNonBlocking(const string& device, - DeviceLocality* locality) override { - return wrapped_->GetDeviceLocalityNonBlocking(device, locality); - } - - void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, - StatusCallback done) override { - wrapped_->GetDeviceLocalityAsync(device, locality, done); - } - - void SetLogging(bool active) override { wrapped_->SetLogging(active); } - - void ClearLogs() override { wrapped_->ClearLogs(); } - - bool RetrieveLogs(int64 step_id, StepStats* ss) override { - return wrapped_->RetrieveLogs(step_id, ss); - } - - private: - WorkerCacheInterface* wrapped_; - - // Information kept per created WorkerInterface. - struct WorkerState { - WorkerInterface* worker; - // TODO(jeff,sanjay): Add reference count if we support eviction. - }; - - // TODO(jeff,sanjay): Eviction when the map becomes too big. 
- mutex mu_; - std::unordered_map workers_ GUARDED_BY(mu_); -}; - void RpcRemoteRendezvous::RecvFromRemoteAsync( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { + CHECK(is_initialized()); Status s; // Prepare a RecvTensor call that can handle being aborted. @@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( s = errors::Internal(parsed.src_device, " is invalid remote source device."); } - WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_); + WorkerSession* sess = session(); + WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_); if (s.ok() && rwi == nullptr) { s = errors::Internal("No worker known as ", call->src_worker_); } Device* dst_device; if (s.ok()) { - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); } if (!s.ok()) { - get_call_freelist()->Release(call, cache_); + if (rwi != nullptr) { + sess->worker_cache->ReleaseWorker(call->src_worker_, rwi); + } + get_call_freelist()->Release(call, sess->worker_cache.get()); done(s, Args(), recv_args, Tensor{}, false); return; } @@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( // current status should be bad. 
Status s = call->status(); call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); - cache_->ReleaseWorker(call->src_worker_, call->wi_); + session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_); call->wi_ = nullptr; - get_call_freelist()->Release(call, cache_); + get_call_freelist()->Release(call, session()->worker_cache.get()); Unref(); }); } } // namespace -RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env, - const string& worker_name, - WorkerCacheInterface* worker_cache) - : BaseRendezvousMgr(env, worker_name), - cache_(new WorkerFreeListCache(worker_cache)) {} +RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env) + : BaseRendezvousMgr(env) {} BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, - const WorkerEnv* worker_env, - const string& worker_name) { - return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(), - step_id); + const WorkerEnv* worker_env) { + return new RpcRemoteRendezvous(worker_env, step_id); } } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h index 75dc62d98fd..34c48a79177 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h @@ -17,13 +17,13 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/distributed_runtime/worker_session.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { +class DeviceMgr; + // RendezvousMgr keeps track of a set of local rendezvous instances. // All tensors sent by this worker are buffered in a RendezvousMgr // until the tensor is received. 
Each global unique "step_id" @@ -44,17 +44,12 @@ namespace tensorflow { // RendezvousMgr must have keys generated by Rendezvous::CreateKey. class RpcRendezvousMgr : public BaseRendezvousMgr { public: - explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache); + explicit RpcRendezvousMgr(const WorkerEnv* env); protected: - BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env, - const string& session_name) override; + BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env); private: - // Private cache_ that allows us to reuse WorkerInterface objects. - std::unique_ptr cache_; - TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr); }; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 9b778eab3a5..2d0d76623d4 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test { : cache_(new DummyWorkerCache), worker_session_("/job:mnist/replica:1/task:2", std::unique_ptr(cache_), - std::unique_ptr(), + std::unique_ptr(), std::unique_ptr()), - rmgr_(&env, worker_session_.worker_name, cache_) { + rmgr_(&env) { env.env = Env::Default(); } @@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) { "/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); core::ScopedUnref unref(rendez); Rendezvous::Args args; TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); @@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { // Explicit Abort(). 
const int64 step_id = 123; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); SchedClosure([this, rendez]() { env.env->SleepForMicroseconds(100 * 1000); @@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { Tensor val(DT_STRING); bool val_dead = false; Rendezvous::Args args; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); } { // Cleanup causes Abort(). const int64 step_id = 321; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); SchedClosure([this, step_id]() { env.env->SleepForMicroseconds(100 * 1000); @@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { Tensor val(DT_STRING); bool val_dead = false; Rendezvous::Args args; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); } } @@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) { "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { const int64 step_id = 123; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); core::ScopedUnref unref(rendez); Rendezvous::Args args; TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); @@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) { "/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); Rendezvous::Args args; args.device_context = dc; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); } { diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc 
b/tensorflow/core/distributed_runtime/session_mgr.cc index e2be62f816c..22551d54821 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -17,8 +17,9 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -26,23 +27,12 @@ namespace tensorflow { SessionMgr::SessionMgr( WorkerEnv* worker_env, const string& default_worker_name, std::unique_ptr default_worker_cache, - std::unique_ptr default_rendezvous_mgr, - WorkerCacheFactory worker_cache_factory) - : SessionMgr( - worker_env, default_worker_name, std::move(default_worker_cache), - default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {} - -SessionMgr::SessionMgr( - WorkerEnv* worker_env, const string& default_worker_name, - std::unique_ptr default_worker_cache, - RendezvousMgrInterface* default_rendezvous_mgr, WorkerCacheFactory worker_cache_factory) : worker_env_(worker_env), - legacy_session_( - default_worker_name, std::move(default_worker_cache), - std::unique_ptr(default_rendezvous_mgr), - std::unique_ptr( - new GraphMgr(worker_env, default_rendezvous_mgr))), + legacy_session_(default_worker_name, std::move(default_worker_cache), + std::unique_ptr(worker_env->device_mgr), + std::unique_ptr( + new GraphMgr(worker_env, worker_env->device_mgr))), worker_cache_factory_(std::move(worker_cache_factory)) {} string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { @@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { Status SessionMgr::CreateSession(const string& session, const ServerDef& server_def) { mutex_lock l(mu_); + if (session.empty()) { + return 
errors::InvalidArgument("Session must be non-empty."); + } + const string worker_name = WorkerNameFromServerDef(server_def); WorkerCacheInterface* worker_cache = nullptr; TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); - std::unique_ptr rendezvous_mgr( - new RpcRendezvousMgr(worker_env_, worker_name, worker_cache)); + std::vector renamed_devices; + for (Device* d : worker_env_->local_devices) { + renamed_devices.push_back( + RenamedDevice::NewRenamedDevice(worker_name, d, false)); + } + std::unique_ptr device_mgr(new DeviceMgr(renamed_devices)); std::unique_ptr graph_mgr( - new GraphMgr(worker_env_, rendezvous_mgr.get())); + new GraphMgr(worker_env_, device_mgr.get())); std::unique_ptr worker_session(new WorkerSession( worker_name, std::unique_ptr(worker_cache), - std::move(rendezvous_mgr), std::move(graph_mgr))); + std::move(device_mgr), std::move(graph_mgr))); sessions_.insert(std::make_pair(session, std::move(worker_session))); return Status::OK(); @@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) { if (it != sessions_.end()) { sessions_.erase(it); } - std::set graph_handles; - for (auto graph_handle_it = sessions_by_graph_handle_.begin(); - graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) { - if (graph_handle_it->second == session) { - graph_handles.insert(graph_handle_it->first); - graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it); - if (graph_handle_it == sessions_by_graph_handle_.end()) break; - } - } - for (auto step_id_it = graphs_by_step_id_.begin(); - step_id_it != graphs_by_step_id_.end(); ++step_id_it) { - if (graph_handles.find(step_id_it->second) != graph_handles.end()) { - step_id_it = graphs_by_step_id_.erase(step_id_it); - if (step_id_it == graphs_by_step_id_.end()) break; - } - } return Status::OK(); } @@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) { WorkerSession* SessionMgr::LegacySession() { return 
&legacy_session_; } -WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked( - const string& graph_handle) { - auto it = sessions_by_graph_handle_.find(graph_handle); - if (it == sessions_by_graph_handle_.end()) { - return &legacy_session_; - } else { - return WorkerSessionForSessionUnlocked(it->second); - } -} - -WorkerSession* SessionMgr::WorkerSessionForGraphHandle( - const string& graph_handle) { - mutex_lock l(mu_); - return WorkerSessionForGraphHandleUnlocked(graph_handle); -} - -WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) { - mutex_lock l(mu_); - auto it = graphs_by_step_id_.find(step_id); - if (it == graphs_by_step_id_.end()) { - return &legacy_session_; - } else { - return WorkerSessionForGraphHandleUnlocked(it->second); - } -} - -void SessionMgr::AssociateGraphWithSession(const string& session, - const string& graph_handle) { - mutex_lock l(mu_); - sessions_by_graph_handle_[graph_handle] = session; -} - -void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) { - mutex_lock l(mu_); - auto it = sessions_by_graph_handle_.find(graph_handle); - if (it != sessions_by_graph_handle_.end()) { - sessions_by_graph_handle_.erase(it); - } -} - -void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle, - const int64 step_id) { - mutex_lock l(mu_); - graphs_by_step_id_[step_id] = graph_handle; -} - -void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) { - mutex_lock l(mu_); - auto it = graphs_by_step_id_.find(step_id); - if (it != graphs_by_step_id_.end()) { - graphs_by_step_id_.erase(it); - } -} - } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 455b5c8d9d9..c44bca7b7a4 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -30,6 +30,8 @@ struct WorkerEnv; // SessionMgr keeps track of information related to a given session. 
// +// SessionMgr runs on the workers. +// // SessionMgr is threadsafe. class SessionMgr { public: @@ -39,7 +41,6 @@ class SessionMgr { explicit SessionMgr( WorkerEnv* worker_env, const string& default_worker_name, std::unique_ptr default_worker_cache, - std::unique_ptr default_rendezvous_mgr, WorkerCacheFactory worker_cache_factory); ~SessionMgr() {} @@ -50,49 +51,36 @@ class SessionMgr { WorkerSession* WorkerSessionForSession(const string& session); WorkerSession* LegacySession(); - // Locates the worker session for a given graph handle - WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle); - void AssociateGraphWithSession(const string& session, - const string& graph_handle); - void DisassociateGraphFromSession(const string& graph_handle); - - // Locates a worker session for a given step id - WorkerSession* WorkerSessionForStepId(const int64 step_id); - void AssociateStepIdWithGraph(const string& graph_handle, - const int64 step_id); - void DisassociateStepIdFromGraph(const int64 step_id); - Status DeleteSession(const string& session); static string WorkerNameFromServerDef(const ServerDef& server_def); private: - // Private constructor to work around std::unique_ptr ownership issues. - explicit SessionMgr( - WorkerEnv* worker_env, const string& default_worker_name, - std::unique_ptr default_worker_cache, - RendezvousMgrInterface* default_rendezvous_mgr, - WorkerCacheFactory worker_cache_factory); - const WorkerEnv* const worker_env_; // Not owned. + + // A note about destruction: + // We must delete graph_mgr before device_mgr, due to shared + // ownership of OpKernels in the executors. (The graph_mgr will + // free all stateless OpKernels, and pass over borrowed stateful + // OpKernels, which are also held in their respective devices' + // OpSegments.) 
+ // + // legacy_session_ owns the worker_env_.device_mgr, and so we must ensure + // that sessions_'s WorkerSessions are deleted (which do not own the + // underlying devices, but instead own RenamedDevices) before + // legacy_session_ is deleted. Further, we must ensure that WorkerSession's + // device_mgr is deleted after WorkerSession's graph_mgr. + WorkerSession legacy_session_; const WorkerCacheFactory worker_cache_factory_; WorkerSession* WorkerSessionForSessionUnlocked(const string& session) EXCLUSIVE_LOCKS_REQUIRED(mu_); - WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle) - EXCLUSIVE_LOCKS_REQUIRED(mu_); mutex mu_; // A map from session identifier to internal session structure. std::map> sessions_ GUARDED_BY(mu_); - - // A map from graph handles to the session that they belong to. - std::map sessions_by_graph_handle_ GUARDED_BY(mu_); - - // A map from globally-unique step id's to the corresponding graph handles. - std::map graphs_by_step_id_ GUARDED_BY(mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc index d3f3fa83958..7132f123a59 100644 --- a/tensorflow/core/distributed_runtime/session_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc @@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test { SessionMgrTest() : mgr_(&env_, "/job:mnist/replica:0/task:0", std::unique_ptr(), - std::unique_ptr(new RpcRendezvousMgr( - &env_, "/job:mnist/replica:0/task:0", nullptr)), factory_), legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {} @@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) { TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - + EXPECT_NE(mgr_.LegacySession(), session); 
TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } -TEST_F(SessionMgrTest, AssociateGraphWithSession) { +TEST_F(SessionMgrTest, LegacySession) { ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); + string session_handle = ""; WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(session, graph_session); + EXPECT_EQ(mgr_.LegacySession(), session); TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } -TEST_F(SessionMgrTest, AssociateStepWithGraph) { - ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); - WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(session, graph_session); - - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(session, step_session); - ASSERT_EQ(graph_session, step_session); - - TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); -} - -TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(legacy_session_, 
graph_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) { - ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); - WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(legacy_session_, graph_session); - - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - TEST_F(SessionMgrTest, WorkerNameFromServerDef) { ServerDef server_def; server_def.set_job_name("worker"); diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 89639e21b5d..07bb17981d3 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* 
request, Status s = session->graph_mgr->Register( request->session_handle(), request->graph_def(), request->graph_options(), request->debug_options(), response->mutable_graph_handle()); - if (s.ok()) { - env_->session_mgr->AssociateGraphWithSession(request->session_handle(), - response->graph_handle()); - } done(s); } @@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, DeregisterGraphResponse* response, StatusCallback done) { WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); Status s = session->graph_mgr->Deregister(request->graph_handle()); - env_->session_mgr->DisassociateGraphFromSession(request->graph_handle()); done(s); } @@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id, } void Worker::AbortStep(int64 step_id) { - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); - Rendezvous* rendez = session->rendezvous_mgr->Find(step_id); + Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id); SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { // Delay a bit before aborting the step. 
This way, the root // cause may return first back to the client instead of this @@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, const int64 step_id = request->step_id(); TRACEPRINTF("RunGraph: %lld", step_id); WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); - env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; Status s = PrepareRunGraph(request, &in, out); @@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, } CostGraphDef* cost_graph = response->mutable_cost_graph(); session->graph_mgr->ExecuteAsync( - request->graph_handle(), step_id, request->exec_opts(), collector, - cost_graph, cm, in, + request->graph_handle(), step_id, session, request->exec_opts(), + collector, cost_graph, cm, in, [this, step_id, response, session, cm, out, token, collector, opts, done](Status s) { if (s.ok()) { @@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, const string& graph_handle = request->graph_handle(); TRACEPRINTF("PartialRunGraph: %lld", step_id); WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(graph_handle); - env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); + GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; Status s = PrepareRunGraph(request, &in, out); @@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, [cm]() { cm->StartCancel(); }); } session->graph_mgr->ExecuteAsync( - graph_handle, step_id, request->exec_opts(), nullptr /* collector */, - nullptr /* cost_graph */, cm, in, + graph_handle, step_id, session, request->exec_opts(), + nullptr /* collector */, nullptr 
/* cost_graph */, cm, in, [this, token, graph_handle, step_id, cm](Status s) { { mutex_lock l(mu_); @@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request, CleanupGraphResponse* response, StatusCallback done) { const int64 step_id = request->step_id(); - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); - session->rendezvous_mgr->Cleanup(step_id); + env_->rendezvous_mgr->Cleanup(step_id); done(Status::OK()); } @@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request, Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, Device** src_dev) { // Figures out which device the tensor is hosted on. - TF_RETURN_IF_ERROR( - env_->device_mgr->LookupDevice(parsed.src_device, src_dev)); + string local_name = DeviceNameUtils::LocalName(parsed.src_device); + TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev)); // Does the device have the right incarnation number we expect? if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h index 24fb5948a71..f09bea328fd 100644 --- a/tensorflow/core/distributed_runtime/worker_env.h +++ b/tensorflow/core/distributed_runtime/worker_env.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ +#include #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -24,8 +25,10 @@ namespace thread { class ThreadPool; } // namespace thread +class Device; class DeviceMgr; class Env; +class RendezvousMgrInterface; class SessionMgr; // The worker environment class, which holds a bag of pointers to @@ -38,10 +41,18 @@ struct WorkerEnv { // session_mgr encapsulates state for each session. SessionMgr* session_mgr = nullptr; + // The local devices of this worker. Devices are owned by the device_mgr. 
+ // + // REQUIRES: !local_devices.empty(). + std::vector local_devices; + // device_mgr manages local devices (cpu and gpu). The WorkerService // is the network interface for managed devices. DeviceMgr* device_mgr = nullptr; + // A set of rendezvous keyed by step ids. + RendezvousMgrInterface* rendezvous_mgr = nullptr; + // A pool of threads for scheduling compute work. thread::ThreadPool* compute_pool = nullptr; }; diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index 508bc7f4680..c9db28ec67f 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -113,6 +113,11 @@ class WorkerInterface { return CallAndWait(&ME::GetStatusAsync, request, response); } + Status CreateWorkerSession(const CreateWorkerSessionRequest* request, + CreateWorkerSessionResponse* response) { + return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); + } + Status RegisterGraph(const RegisterGraphRequest* request, RegisterGraphResponse* response) { return CallAndWait(&ME::RegisterGraphAsync, request, response); diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index 8298e169595..8691450e9bc 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -17,14 +17,84 @@ limitations under the License. namespace tensorflow { -WorkerSession::WorkerSession( - const string& worker_name, - std::unique_ptr worker_cache, - std::unique_ptr rendezvous_mgr, - std::unique_ptr graph_mgr) +namespace { + +// A private cache that wraps worker_cache and allows reuse of +// WorkerInterface objects. 
+class WorkerFreeListCache : public WorkerCacheInterface { + public: + explicit WorkerFreeListCache(std::unique_ptr w) + : wrapped_(std::move(w)) {} + + ~WorkerFreeListCache() final { + for (auto p : workers_) { + wrapped_->ReleaseWorker(p.first, p.second.worker); + } + } + + void ListWorkers(std::vector* workers) const override { + wrapped_->ListWorkers(workers); + } + + WorkerInterface* CreateWorker(const string& target) override { + mutex_lock l(mu_); + auto p = workers_.find(target); + if (p != workers_.end()) { + return p->second.worker; + } + WorkerState state; + state.worker = wrapped_->CreateWorker(target); + if (state.worker != nullptr) { + workers_.insert(std::make_pair(target, state)); + } + return state.worker; + } + + void ReleaseWorker(const string& target, WorkerInterface* worker) override { + // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. + } + + bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) override { + return wrapped_->GetDeviceLocalityNonBlocking(device, locality); + } + + void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, + StatusCallback done) override { + wrapped_->GetDeviceLocalityAsync(device, locality, done); + } + + void SetLogging(bool active) override { wrapped_->SetLogging(active); } + + void ClearLogs() override { wrapped_->ClearLogs(); } + + bool RetrieveLogs(int64 step_id, StepStats* ss) override { + return wrapped_->RetrieveLogs(step_id, ss); + } + + private: + std::unique_ptr wrapped_; + + // Information kept per created WorkerInterface. + struct WorkerState { + WorkerInterface* worker; + // TODO(jeff,sanjay): Add reference count if we support eviction. + }; + + // TODO(jeff,sanjay): Eviction when the map becomes too big. 
+ mutex mu_; + std::unordered_map workers_ GUARDED_BY(mu_); +}; + +} // namespace + +WorkerSession::WorkerSession(const string& worker_name, + std::unique_ptr worker_cache, + std::unique_ptr device_mgr, + std::unique_ptr graph_mgr) : worker_name(worker_name), - worker_cache(std::move(worker_cache)), - rendezvous_mgr(std::move(rendezvous_mgr)), + worker_cache(new WorkerFreeListCache(std::move(worker_cache))), + device_mgr(std::move(device_mgr)), graph_mgr(std::move(graph_mgr)) {} } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index e6ebe883298..77cf4de8f74 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -18,14 +18,13 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" -#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" namespace tensorflow { class GraphMgr; -class RendezvousMgrInterface; class WorkerCacheInterface; // WorkerSession encapsulates all of the state relating to a given session. @@ -36,17 +35,20 @@ struct WorkerSession { // Object from which WorkerInterface instances can be obtained. const std::unique_ptr worker_cache; - // A set of rendezvous keyed by step ids. - const std::unique_ptr rendezvous_mgr; + // Collection of local devices. These devices are typically RenamedDevices + // in all except the SessionMgr.legacy_session_. legacy_session_.device_mgr + // == worker_env_.device_mgr, which holds the true devices. + const std::unique_ptr device_mgr; // graph_mgr keeps track of the registered graphs of this session. // // Note: graph_mgr must be deleted before rendezvous_mgr! + // Note: graph_mgr must be deleted before device_mgr! 
const std::unique_ptr graph_mgr; WorkerSession(const string& worker_name, std::unique_ptr worker_cache, - std::unique_ptr rendezvous_mgr, + std::unique_ptr device_mgr, std::unique_ptr graph_mgr); }; diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 8894671fdf3..27fe28fe60a 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -115,7 +115,7 @@ class DeviceBase { cpu_worker_threads_ = t; } - const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { + virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { CHECK(cpu_worker_threads_ != nullptr); return cpu_worker_threads_; } @@ -140,7 +140,7 @@ class DeviceBase { gpu_device_info_ = g; } - const GpuDeviceInfo* tensorflow_gpu_device_info() const { + virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const { return gpu_device_info_; } @@ -170,13 +170,13 @@ class DeviceBase { return GetAllocator(attr); } - const Eigen::ThreadPoolDevice* eigen_cpu_device() { + virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() { CHECK(eigen_cpu_device_ != nullptr); return eigen_cpu_device_; } #ifdef TENSORFLOW_USE_SYCL - const Eigen::SyclDevice* eigen_sycl_device() const { + virtual const Eigen::SyclDevice* eigen_sycl_device() const { CHECK(eigen_sycl_device_ != nullptr); return eigen_sycl_device_; } diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 3d913cdaf0c..6fad379b760 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -656,22 +656,6 @@ Status OpKernelContext::allocate_persistent(DataType type, *out_tensor = out_persistent->AccessTensor(this); } } - if (track_allocations() && persistent.TotalBytes() > 0) { - // TODO(yuefengz): some allocators allocate memory even if the requested - // size is 0. 
- Allocator* a = get_allocator(attr); - if (a->TracksAllocationSizes()) { - int64 alloc_size = - a->AllocatedSize(const_cast(persistent.tensor_data().data())); - int64 alloc_id = - a->AllocationId(const_cast(persistent.tensor_data().data())); - if (allocate_on_host(attr)) { - record_host_persistent_memory_allocation(alloc_size, alloc_id); - } else { - record_device_persistent_memory_allocation(alloc_size, alloc_id); - } - } - } return s; } diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 22f4708d032..372092f42a9 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -111,6 +111,7 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], + defines = if_cuda(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ ":op_performance_data_cc", @@ -167,3 +168,29 @@ cc_library( "//tensorflow/core/kernels:ops_util", ], ) + +cc_library( + name = "op_level_cost_estimator", + srcs = ["op_level_cost_estimator.cc"], + hdrs = ["op_level_cost_estimator.h"], + visibility = ["//visibility:public"], + deps = [ + ":cost_estimator", + ":op_performance_data_cc", + ":utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + ], +) + +cc_test( + name = "op_level_cost_estimator_test", + srcs = ["op_level_cost_estimator_test.cc"], + deps = [ + ":op_level_cost_estimator", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc new file mode 100644 index 00000000000..baed7a88997 --- /dev/null +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -0,0 +1,554 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/grappler/costs/utils.h" + +namespace tensorflow { +namespace grappler { + +constexpr int kOpsPerMac = 2; +constexpr char kConv2d[] = "Conv2D"; +constexpr char kConv2dBackPropFilter[] = "Conv2DBackpropFilter"; +constexpr char kConv2dBackPropInput[] = "Conv2DBackpropInput"; +constexpr char kMatMul[] = "MatMul"; +constexpr char kSparseMatMul[] = "SparseMatMul"; +constexpr char kIdentity[] = "Identity"; +constexpr char kNoOp[] = "NoOp"; +constexpr char kReshape[] = "Reshape"; + +OpLevelCostEstimator::OpLevelCostEstimator() { + // Syntactic sugar to build and return a lambda that takes an OpInfo and + // returns a cost. 
+ typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpInfo& op_feature) + const; + auto wrap = [this](CostImpl impl) -> std::function { + return [this, impl](const OpInfo& op) { return (this->*impl)(op); }; + }; + + device_cost_impl_ = { + {kConv2d, wrap(&OpLevelCostEstimator::PredictConv2D)}, + {kConv2dBackPropFilter, + wrap(&OpLevelCostEstimator::PredictConv2DBackPropFilter)}, + {kConv2dBackPropInput, + wrap(&OpLevelCostEstimator::PredictConv2DBackPropInput)}, + {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, + {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, + {kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)}, + {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, + {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}}; +} + +Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const { + auto it = device_cost_impl_.find(op_features.op()); + if (it == device_cost_impl_.end()) { + VLOG(1) << "Missing implementation for op: " << op_features.op(); + Costs costs; + costs = DummyExecutionTime(op_features); + return costs; + } + + std::function estimator = it->second; + Costs costs = estimator(op_features); + VLOG(1) << "Operation " << op_features.op() << " takes " + << costs.execution_time.count() << " ns."; + return costs; +} + +std::pair OpLevelCostEstimator::GetDeviceInfo( + const OpInfo::DeviceProperties& device) const { + double gflops = -1; + double bandwidth = -1; + if (device.bandwidth() > 0) { + bandwidth = device.bandwidth() / 1e6; + } + + if (device.type() == "CPU") { + const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo(); + // Check if vector instructions are available, and refine performance + // prediction based on this. 
+ gflops = local_cpu.num_cores() * local_cpu.frequency(); + if (bandwidth < 0) { + if (local_cpu.bandwidth() > 0) { + bandwidth = local_cpu.bandwidth() / 1e6; + } else { + bandwidth = 32; + } + } + } else if (device.type() == "GPU") { + const OpInfo::DeviceProperties local_gpu = GetLocalGPUInfo(0); + const string architecture = local_gpu.environment().at("architecture"); + int cores_per_multiprocessor; + if (architecture < "3") { + // Fermi + cores_per_multiprocessor = 32; + } else if (architecture < "4") { + // Kepler + cores_per_multiprocessor = 192; + } else if (architecture < "6") { + // Maxwell + cores_per_multiprocessor = 128; + } else { + // Pascal. + cores_per_multiprocessor = 64; + } + gflops = local_gpu.num_cores() * local_gpu.frequency() * + cores_per_multiprocessor * kOpsPerMac; + if (bandwidth < 0) { + CHECK(local_gpu.bandwidth() > 0); + bandwidth = local_gpu.bandwidth() / 1e6; + } + } + + return std::make_pair(gflops, bandwidth); +} + +Costs OpLevelCostEstimator::DummyExecutionTime( + const OpInfo& op_features) const { + Costs costs = PredictOpCountBasedCost(0, op_features); + costs.inaccurate = true; + return costs; +} + +Costs OpLevelCostEstimator::PredictOpCountBasedCost( + double operations, const OpInfo& op_features) const { + std::pair device_perf = GetDeviceInfo(op_features.device()); + Costs::NanoSeconds compute_cost(operations / device_perf.first); + VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9 + << " Execution Time (ns):" << compute_cost.count(); + + bool found_unknown_shapes = false; + double total_input_size = + CalculateInputSize(op_features, &found_unknown_shapes); + double total_output_size = + CalculateOutputSize(op_features, &found_unknown_shapes); + double total_io_size = total_input_size + total_output_size; + + Costs::NanoSeconds memory_cost(total_io_size / device_perf.second); + VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << (total_io_size) / 1e3 + << " Memory Time (ns):" << memory_cost.count(); + 
+ Costs costs; + costs.compute_time = compute_cost; + costs.memory_time = memory_cost; + costs.execution_time = compute_cost + memory_cost; + costs.inaccurate = found_unknown_shapes; + return costs; +} + +int64 OpLevelCostEstimator::CountConv2DOperations( + const OpInfo& op_features, bool* found_unknown_shapes) const { + return CountConv2DOperations(op_features, nullptr, found_unknown_shapes); +} + +namespace { + +string GetDataFormat(const OpInfo& op_features) { + string data_format = "NHWC"; // Default format. + if (op_features.attr().find("data_format") != op_features.attr().end()) { + data_format = op_features.attr().at("data_format").s(); + } + return data_format; +} + +Padding GetPadding(const OpInfo& op_features) { + if (op_features.attr().find("padding") != op_features.attr().end() && + op_features.attr().at("padding").s() == "VALID") { + return Padding::VALID; + } + return Padding::SAME; // Default padding. +} + +std::vector GetStrides(const OpInfo& op_features) { + if (op_features.attr().find("strides") != op_features.attr().end()) { + const auto strides = op_features.attr().at("strides").list().i(); + return {strides[0], strides[1], strides[2], strides[3]}; + } + return {1, 1, 1, 1}; +} + +int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride, + const Padding& padding) { + // Logic for calculating output shape is from GetWindowedOutputSizeVerbose() + // function in third_party/tensorflow/core/framework/common_shape_fns.cc. + if (padding == Padding::VALID) { + return (input - filter + stride) / stride; + } else { // SAME. + return (input + stride - 1) / stride; + } +} + +// Return a minimum shape if the shape is unknown. If known, return the original +// shape. 
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, + int rank, bool* found_unknown_shapes) { + auto shape = original_shape; + if (shape.unknown_rank()) { + *found_unknown_shapes = true; + } + if (shape.unknown_rank() || shape.dim_size() == 0) { + TensorShapeProto::Dim dim; + VLOG(1) << "WARNING: Use minimum shape because the shape is unknown."; + // The size of each dimension is at least 1, if unknown. + dim.set_size(1); + for (int i = 0; i < rank; i++) { + *shape.add_dim() = dim; + } + } else { + CHECK_EQ(shape.dim_size(), rank); + for (int i = 0; i < rank; i++) { + if (shape.dim(i).size() == -1) { + *found_unknown_shapes = true; + VLOG(1) + << "WARNING: Use minimum dim size 1 because the shape is unknown."; + // The size of each dimension is at least 1, if unknown. + shape.mutable_dim(i)->set_size(1); + } + } + } + return shape; +} +} // namespace + +// Helper to translate the positional arguments into named fields. +OpLevelCostEstimator::ConvolutionDimensions +OpLevelCostEstimator::ConvolutionDimensionsFromInputs( + const TensorShapeProto& original_image_shape, + const TensorShapeProto& original_filter_shape, const OpInfo& op_features, + bool* found_unknown_shapes) { + auto image_shape = + MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes); + auto filter_shape = + MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes); + + int x_index, y_index, channel_index; + const string& data_format = GetDataFormat(op_features); + if (data_format == "NCHW") { + x_index = 2; + y_index = 3; + channel_index = 1; + } else { + x_index = 1; + y_index = 2; + channel_index = 3; + } + int64 batch = image_shape.dim(0).size(); + int64 ix = image_shape.dim(x_index).size(); + int64 iy = image_shape.dim(y_index).size(); + int64 iz = image_shape.dim(channel_index).size(); + int64 kx = filter_shape.dim(0).size(); + int64 ky = filter_shape.dim(1).size(); + std::vector strides = GetStrides(op_features); + const auto padding = 
GetPadding(op_features); + int64 sx = strides[x_index]; + int64 sy = strides[y_index]; + int64 ox = GetOutputSize(ix, kx, sx, padding); + int64 oy = GetOutputSize(iy, ky, sy, padding); + int64 oz = filter_shape.dim(3).size(); + // Only check equality when both sizes are known (in other words, when + // neither is set to a minimum dimension size of 1). + if (iz != 1 && filter_shape.dim(2).size() != 1) { + CHECK_EQ(iz, filter_shape.dim(2).size()); + } else { + iz = std::max(iz, filter_shape.dim(2).size()); + } + OpLevelCostEstimator::ConvolutionDimensions conv_dims = { + batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding}; + + VLOG(1) << "Batch Size:" << batch; + VLOG(1) << "Image Dims:" << ix << "," << iy; + VLOG(1) << "Input Features:" << iz; + VLOG(1) << "Kernel Dims:" << kx << "," << ky; + VLOG(1) << "Output Features:" << oz; + VLOG(1) << "Output Dims:" << ox << "," << oy; + VLOG(1) << "Strides:" << sx << "," << sy; + VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME"); + return conv_dims; +} + +int64 OpLevelCostEstimator::CountConv2DOperations( + const OpInfo& op_features, ConvolutionDimensions* conv_info, + bool* found_unknown_shapes) const { + if (op_features.op() != kConv2d) { + LOG(ERROR) << "Invalid Operation"; + return 0; + } + ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( + op_features.inputs(0).shape(), op_features.inputs(1).shape(), op_features, + found_unknown_shapes); + + int64 ops = conv_dims.batch; + ops *= conv_dims.ox * conv_dims.oy; + ops *= conv_dims.kx * conv_dims.ky; + ops *= conv_dims.iz * conv_dims.oz; + ops *= kOpsPerMac; + VLOG(1) << "Operations for Conv2D" << ops; + + if (conv_info != nullptr) { + *conv_info = conv_dims; + } + return ops; +} + +int64 OpLevelCostEstimator::CountMatMulOperations( + const OpInfo& op_features, bool* found_unknown_shapes) const { + return CountMatMulOperations(op_features, nullptr, found_unknown_shapes); +} + +int64 OpLevelCostEstimator::CountMatMulOperations( + 
const OpInfo& op_features, MatMulDimensions* mat_mul, + bool* found_unknown_shapes) const { + double ops = 0; + + // TODO(nishantpatil): Create separate estimator for Sparse Matmul + if ((op_features.op() != kMatMul) && (op_features.op() != kSparseMatMul)) { + LOG(ERROR) << "Invalid Operation"; + return ops; + } + + // input matrices + auto& a_matrix = op_features.inputs(0); + auto& b_matrix = op_features.inputs(1); + + bool transpose_a = false; + bool transpose_b = false; + + double m_dim, n_dim, k_dim, k_dim_b = 0; + + for (const auto& item : op_features.attr()) { + VLOG(1) << "Key:" << item.first + << " Value:" << SummarizeAttrValue(item.second); + if (item.first == "transpose_a" && item.second.b() == true) + transpose_a = true; + if (item.first == "transpose_b" && item.second.b() == true) + transpose_b = true; + } + VLOG(1) << "transpose_a:" << transpose_a; + VLOG(1) << "transpose_b:" << transpose_b; + auto a_matrix_shape = + MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes); + auto b_matrix_shape = + MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes); + if (transpose_a) { + m_dim = a_matrix_shape.dim(1).size(); + k_dim = a_matrix_shape.dim(0).size(); + } else { + m_dim = a_matrix_shape.dim(0).size(); + k_dim = a_matrix_shape.dim(1).size(); + } + if (transpose_b) { + k_dim_b = b_matrix_shape.dim(1).size(); + n_dim = b_matrix_shape.dim(0).size(); + } else { + k_dim_b = b_matrix_shape.dim(0).size(); + n_dim = b_matrix_shape.dim(1).size(); + } + + VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim; + // Only check equality when both sizes are known (in other words, when + // neither is set to a minimum dimension size of 1). + if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) { + LOG(ERROR) << "Incompatible Matrix dimensions"; + return ops; + } else { + // One of k_dim and k_dim_b might be 1 (minimum dimension size).
+ k_dim = std::max(k_dim, k_dim_b); + } + + ops = m_dim * n_dim * k_dim * 2; + VLOG(1) << "Operations for Matmul" << ops; + + if (mat_mul != nullptr) { + mat_mul->m = m_dim; + mat_mul->n = n_dim; + mat_mul->k = k_dim; + } + return ops; +} + +// TODO(cliffy): Dedup this method and CountConv2DBackPropFilterOperations. +int64 OpLevelCostEstimator::CountConv2DBackPropInputOperations( + const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, + bool* found_unknown_shapes) const { + int64 ops = 0; + + if (op_features.op() != kConv2dBackPropInput) { + LOG(ERROR) << "Invalid Operation"; + return ops; + } + + if (op_features.attr().find("_output_shapes") == op_features.attr().end()) { + // Need _output_shapes for input shape. + LOG(ERROR) << "No output shape in Conv2DBackPropInput op feaure."; + return ops; + } + + const auto& input_shape = + op_features.attr().at("_output_shapes").list().shape(0); + ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( + input_shape, op_features.inputs(1).shape(), op_features, + found_unknown_shapes); + + ops = conv_dims.batch; + ops *= conv_dims.ox * conv_dims.oy; + ops *= conv_dims.kx * conv_dims.ky; + ops *= conv_dims.iz * conv_dims.oz; + ops *= kOpsPerMac; + + VLOG(1) << "Operations for Conv2DBackPropInput" << ops; + + if (returned_conv_dims != nullptr) { + *returned_conv_dims = conv_dims; + } + return ops; +} + +int64 OpLevelCostEstimator::CountConv2DBackPropFilterOperations( + const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, + bool* found_unknown_shapes) const { + int64 ops = 0; + if (op_features.op() != kConv2dBackPropFilter) { + LOG(ERROR) << "Invalid Operation"; + return ops; + } + + if (op_features.attr().find("_output_shapes") == op_features.attr().end()) { + // Need _output_shapes for filter shape. 
+ LOG(ERROR) << "No output shape in Conv2DBackPropFilter op feaure."; + return ops; + } + + const auto& filter_shape = + op_features.attr().at("_output_shapes").list().shape(0); + ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( + op_features.inputs(0).shape(), filter_shape, op_features, + found_unknown_shapes); + + ops = conv_dims.batch; + ops *= conv_dims.ox * conv_dims.oy; + ops *= conv_dims.kx * conv_dims.ky; + ops *= conv_dims.iz * conv_dims.oz; + ops *= kOpsPerMac; + + VLOG(1) << "Operations for Conv2DBackPropFilter" << ops; + + if (returned_conv_dims != nullptr) { + *returned_conv_dims = conv_dims; + } + return ops; +} + +int64 OpLevelCostEstimator::CalculateSingleInputSize( + const OpInfo::TensorProperties& input, bool* found_unknown_shapes) const { + VLOG(1) << " with " << input.dtype() << " input of shape " + << input.shape().DebugString(); + int64 input_size = 1; + int num_dims = std::max(1, input.shape().dim_size()); + auto input_shape = + MaybeGetMinimumShape(input.shape(), num_dims, found_unknown_shapes); + for (const auto& dim : input_shape.dim()) { + input_size *= dim.size(); + } + return input_size * DataTypeSize(input.dtype()); +} + +int64 OpLevelCostEstimator::CalculateInputSize( + const OpInfo& op_features, bool* found_unknown_shapes) const { + int64 total_input_size = 0; + for (auto& input : op_features.inputs()) { + int64 input_size = CalculateSingleInputSize(input, found_unknown_shapes); + total_input_size += input_size; + VLOG(1) << "Input Size: " << input_size + << " Total Input Size:" << total_input_size; + } + return total_input_size; +} + +int64 OpLevelCostEstimator::CalculateOutputSize( + const OpInfo& op_features, bool* found_unknown_shapes) const { + int64 total_output_size = 0; + // use float as default for calculations + DataType dt = DT_FLOAT; + for (const auto& item : op_features.attr()) { + VLOG(1) << "Key:" << item.first + << " Value:" << SummarizeAttrValue(item.second); + if (item.first == "_output_shapes") { 
+ for (const auto& original_output_shape : item.second.list().shape()) { + int64 output_size = 1; + int num_dims = std::max(1, original_output_shape.dim_size()); + auto output_shape = MaybeGetMinimumShape( + original_output_shape, num_dims, found_unknown_shapes); + for (const auto& dim : output_shape.dim()) { + output_size *= dim.size(); + } + output_size *= DataTypeSize(dt); + total_output_size += output_size; + VLOG(1) << "Output Size: " << output_size + << " Total Output Size:" << total_output_size; + } + } + if (item.first == "T") { + dt = item.second.type(); + } + } + return total_output_size; +} + +Costs OpLevelCostEstimator::PredictConv2D(const OpInfo& op_features) const { + bool found_unknown_shapes = false; + auto costs = PredictOpCountBasedCost( + CountConv2DOperations(op_features, &found_unknown_shapes), op_features); + costs.inaccurate = found_unknown_shapes; + return costs; +} + +Costs OpLevelCostEstimator::PredictConv2DBackPropInput( + const OpInfo& op_features) const { + bool found_unknown_shapes = false; + auto costs = + PredictOpCountBasedCost(CountConv2DBackPropInputOperations( + op_features, nullptr, &found_unknown_shapes), + op_features); + costs.inaccurate = found_unknown_shapes; + return costs; +} + +Costs OpLevelCostEstimator::PredictConv2DBackPropFilter( + const OpInfo& op_features) const { + bool found_unknown_shapes = false; + auto costs = + PredictOpCountBasedCost(CountConv2DBackPropFilterOperations( + op_features, nullptr, &found_unknown_shapes), + op_features); + costs.inaccurate = found_unknown_shapes; + return costs; +} + +Costs OpLevelCostEstimator::PredictMatMul(const OpInfo& op_features) const { + bool found_unknown_shapes = false; + auto costs = PredictOpCountBasedCost( + CountMatMulOperations(op_features, &found_unknown_shapes), op_features); + costs.inaccurate = found_unknown_shapes; + return costs; +} + +Costs OpLevelCostEstimator::PredictNoOp(const OpInfo& op_features) const { + VLOG(1) << "Op:" << op_features.op() << " 
Execution Time 0 (ns)"; + return Costs::ZeroCosts(); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h new file mode 100644 index 00000000000..5bb20cc6bbf --- /dev/null +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ + +#include +#include +#include + +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { +namespace grappler { + +class OpLevelCostEstimator { + public: + OpLevelCostEstimator(); + virtual ~OpLevelCostEstimator() {} + + Costs PredictCosts(const OpInfo& op_features) const; + + protected: + // Returns an estimate of device performance (in billions of operations + // executed per second) and memory bandwidth (in GigaBytes/second) for the + // specified device.
+ virtual std::pair GetDeviceInfo( + const OpInfo::DeviceProperties& device) const; + + // For operations for which we haven't yet built estimates, returns a dummy + // value based on input and output size, flagged as inaccurate. + Costs DummyExecutionTime(const OpInfo& op_features) const; + + // Naive cost estimate: operations / device ops/sec plus I/O bytes / bandwidth. + Costs PredictOpCountBasedCost(double operations, + const OpInfo& op_features) const; + + // This family of routines counts the number of operations to perform the + // specified TensorFlow Op. + struct MatMulDimensions { + int m; + int n; + int k; + }; + struct ConvolutionDimensions { + int64 batch; // Batch size. + int64 ix; // Input size x. + int64 iy; // Input size y. + int64 iz; // Input depth. + int64 kx; // Kernel x. + int64 ky; // Kernel y. + int64 oz; // Output depth. + int64 ox; // Output size x. + int64 oy; // Output size y. + int64 sx; // Stride x. + int64 sy; // Stride y. + Padding padding; // SAME or VALID. + }; + int64 CountConv2DOperations(const OpInfo& op_features, + bool* found_unknown_shapes) const; + int64 CountConv2DOperations(const OpInfo& op_features, + ConvolutionDimensions* conv_info, + bool* found_unknown_shapes) const; + int64 CountMatMulOperations(const OpInfo& op_features, + bool* found_unknown_shapes) const; + int64 CountMatMulOperations(const OpInfo& op_features, + MatMulDimensions* mat_mul, + bool* found_unknown_shapes) const; + int64 CountConv2DBackPropInputOperations(const OpInfo& op_features, + ConvolutionDimensions* conv_info, + bool* found_unknown_shapes) const; + int64 CountConv2DBackPropFilterOperations(const OpInfo& op_features, + ConvolutionDimensions* conv_info, + bool* found_unknown_shapes) const; + + // Calculate the total size in bytes of a single input to a TensorFlow op.
+ int64 CalculateSingleInputSize(const OpInfo::TensorProperties& input, + bool* found_unknown_shapes) const; + + // Calculate the total size in bytes of all + // the inputs of the specified TensorFlow Op. + int64 CalculateInputSize(const OpInfo& op_features, + bool* found_unknown_shapes) const; + + // Calculate the total size in bytes of all + // the outputs of the specified TensorFlow Op. + int64 CalculateOutputSize(const OpInfo& op_features, + bool* found_unknown_shapes) const; + + // This family of routines predicts the costs to + // perform the specified TensorFlow Op on the + // device represented by a subclass. The default + // implementation just divides the operations to + // perform the op (from the "Count" routines, + // above) by the device peak operations per + // second. Override to supply a better estimate. + // Implementation of costs other than + // execution_time is optional, depending on the + // device. + Costs PredictConv2D(const OpInfo& op_features) const; + Costs PredictConv2DBackPropInput(const OpInfo& op_features) const; + Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const; + Costs PredictMatMul(const OpInfo& op_features) const; + Costs PredictNoOp(const OpInfo& op_features) const; + + // Utility function for safe division. Returns 0 + // if rhs is 0 or negative.
+ static double SafeDiv(const double lhs, const double rhs) { + if (rhs > 0) { + return lhs / rhs; + } else { + return 0.0; + } + } + + static ConvolutionDimensions ConvolutionDimensionsFromInputs( + const TensorShapeProto& original_image_shape, + const TensorShapeProto& original_filter_shape, const OpInfo& op_features, + bool* found_unknown_shapes); + + private: + typedef std::function CostImpl; + std::map device_cost_impl_; +}; + +} // end namespace grappler +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc new file mode 100644 index 00000000000..e0b0348c8ec --- /dev/null +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +namespace { +// Wrangles the minimum number of proto fields to set up a matrix. 
+void DescribeMatrix(int rows, int columns, OpInfo *op_features) { + auto input = op_features->add_inputs(); + auto shape = input->mutable_shape(); + auto shape_rows = shape->add_dim(); + shape_rows->set_size(rows); + auto shape_columns = shape->add_dim(); + shape_columns->set_size(columns); + input->set_dtype(DT_FLOAT); +} + +// Returns an OpInfo for MatMul with the minimum set of fields set up. +OpInfo DescribeMatMul(int m, int n, int l, int k) { + OpInfo op_features; + auto device = op_features.mutable_device(); + device->set_type("CPU"); + op_features.set_op("MatMul"); + + DescribeMatrix(m, l, &op_features); + DescribeMatrix(k, n, &op_features); + return op_features; +} + +// Returns an OpInfo for MatMul with unknown input shapes. +OpInfo DescribeMatMulUnknownShape() { + OpInfo op_features; + auto device = op_features.mutable_device(); + device->set_type("CPU"); + op_features.set_op("MatMul"); + + auto input = op_features.add_inputs(); + auto shape = input->mutable_shape(); + shape->set_unknown_rank(true); + + input = op_features.add_inputs(); + shape = input->mutable_shape(); + shape->set_unknown_rank(true); + + return op_features; +} + +// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost +// estimation purposes. +void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3, + OpInfo *op_features) { + auto input = op_features->add_inputs(); + auto shape = input->mutable_shape(); + shape->add_dim()->set_size(dim0); + shape->add_dim()->set_size(dim1); + shape->add_dim()->set_size(dim2); + shape->add_dim()->set_size(dim3); +} + +// Returns an OpInfo for Conv2D with the minimum set of fields set up. 
+OpInfo DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2, int kx, + int ky, int oz) { + OpInfo op_features; + auto device = op_features.mutable_device(); + device->set_type("CPU"); + op_features.set_op("Conv2D"); + + DescribeTensor4D(batch, ix, iy, iz1, &op_features); + DescribeTensor4D(kx, ky, iz2, oz, &op_features); + return op_features; +} +} // namespace + +TEST(OpLevelCostEstimatorTest, UnknownOrPartialShape) { + OpLevelCostEstimator estimator; + + EXPECT_EQ(false, + estimator.PredictCosts(DescribeMatMul(2, 4, 7, 7)).inaccurate); + EXPECT_EQ(true, + estimator.PredictCosts(DescribeMatMul(-1, 4, 7, 7)).inaccurate); + EXPECT_EQ(true, + estimator.PredictCosts(DescribeMatMul(2, 4, -1, 7)).inaccurate); + + EXPECT_EQ( + false, + estimator.PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256)) + .inaccurate); + EXPECT_EQ( + true, + estimator.PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256)) + .inaccurate); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 4e35de9d4a6..0852cb4fd3a 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -147,7 +147,7 @@ OpInfo::DeviceProperties GetLocalCPUInfo() { // Combine cpu family and model into the model string. 
device.set_model( strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum())); - device.set_frequency(port::NominalCPUFrequency()); + device.set_frequency(port::NominalCPUFrequency() * 1e-9); device.set_num_cores(port::NumSchedulableCPUs()); device.set_l1_cache_size(Eigen::l1CacheSize()); device.set_l2_cache_size(Eigen::l2CacheSize()); @@ -195,6 +195,8 @@ OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) { properties.memoryClockRate * 2); } + (*device.mutable_environment())["architecture"] = + strings::StrCat(properties.major, ".", properties.minor); (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION); (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION); #endif diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index b7785c94e04..5d437dff50e 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -26,7 +26,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace { struct NodeState { std::vector inputs; std::vector outputs; @@ -86,7 +85,6 @@ class FIFOManager : public ReadyNodeManager { private: std::list nodes_; }; -} // namespace // The virtual scheduler emulates execution of nodes in a graph, considering // dependencies, device, etc. diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index bafbcc200c4..64bdd910773 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -18,6 +18,11 @@ limitations under the License. 
namespace tensorflow { namespace grappler { +bool IsConcat(const NodeDef& node) { + const auto op = node.op(); + return op == "Concat" || op == "ConcatV2"; +} + bool IsDequeueOp(const NodeDef& node) { static const std::set dequeue_ops = { "QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2", @@ -30,6 +35,11 @@ bool IsPlaceholder(const NodeDef& node) { return op == "Placeholder" || op == "PlaceholderV2"; } +bool IsTranspose(const NodeDef& node) { + const auto op = node.op(); + return op == "Transpose"; +} + bool IsVariable(const NodeDef& node) { const auto op = node.op(); return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 2f58835628d..4f2bb2bc056 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -21,8 +21,10 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool IsConcat(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); +bool IsTranspose(const NodeDef& node); bool IsVariable(const NodeDef& node); } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index e3b36c84123..5f30dfbaa26 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -205,11 +205,28 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", ], ) +cc_test( + name = "layout_optimizer_test", + srcs = ["layout_optimizer_test.cc"], + deps = [ + ":layout_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + 
"//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + cc_library( name = "meta_optimizer", srcs = ["meta_optimizer.cc"], diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 9570ec17d05..e37c4a5b36a 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/devices.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -68,8 +69,7 @@ std::set GetOpsFormatAgnostic() { "Slice", "SquaredDifference", "Squeeze", - "Sub", - "Sum"}; + "Sub"}; return ops_format_agnostic; } @@ -110,9 +110,9 @@ class NodeProcessor { } protected: - bool IsDimsN(NodeDef* node, int n) const { - if (node->attr().find("_output_shapes") != node->attr().end()) { - auto shape = node->attr().at("_output_shapes").list().shape(0); + bool IsDimsN(const NodeDef& node, int n) const { + if (node.attr().find("_output_shapes") != node.attr().end()) { + auto shape = node.attr().at("_output_shapes").list().shape(0); if (shape.dim_size() == n) { return true; } @@ -120,7 +120,7 @@ class NodeProcessor { return false; } - bool IsDimsFour(NodeDef* node) const { return IsDimsN(node, 4); } + bool IsDimsFour(const NodeDef& node) const { return IsDimsN(node, 4); } bool IsNHWC() const { if (node_->attr().find("data_format") != node_->attr().end()) { @@ -145,7 +145,7 @@ class NodeProcessor { } virtual bool ShouldProcess() const { - return IsNHWC() && IsDimsFour(node_) && HasOutputs(); + return IsNHWC() && 
IsDimsFour(*node_) && HasOutputs(); } void UpdateAttrDataFormat() { @@ -268,6 +268,8 @@ class NodeProcessor { for (const auto& output : outputs) { string node_name_NCHWToNHWC = strings::StrCat( kTransposeNCHWToNHWC, "-", node_->name(), "-", output->name()); + // TODO (yaozhang): handle the rare case where node A is connected to more + // than one input of node B. auto it = std::find_if(output->mutable_input()->begin(), output->mutable_input()->end(), [this](const string& input) { @@ -341,7 +343,7 @@ class BiasAddGradProcessor : public NodeProcessor { bool ShouldProcess() const override { auto input = node_map_->GetNode(node_->input(0)); if (input) { - if ((IsNHWC() && IsDimsFour(input)) || IsNodeNCHWToNHWC(input->name())) { + if ((IsNHWC() && IsDimsFour(*input)) || IsNodeNCHWToNHWC(input->name())) { return true; } } @@ -351,13 +353,89 @@ class BiasAddGradProcessor : public NodeProcessor { Status AddLayoutTransposeToOutputs() override { return Status::OK(); } }; -class Conv2DBackpropFilterProcessor : public NodeProcessor { +class Conv2DProcessor : public NodeProcessor { public: - Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node, - NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + Conv2DProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool no_gemm) + : NodeProcessor(graph, node, node_map), no_gemm_(no_gemm) {} protected: + bool ShouldProcess() const override { + return IsNHWC() && IsDimsFour(*node_) && HasOutputs() && + (!IsGemmUsed() || no_gemm_); + } + + TensorShapeProto GetShape(const string& input_name) const { + string node_name; + int output_pos; + node_name = ParseNodeName(input_name, &output_pos); + NodeDef* node = node_map_->GetNode(node_name); + if (node->attr().find("_output_shapes") != node->attr().end()) { + return node->attr().at("_output_shapes").list().shape(output_pos); + } + TensorShapeProto shape; + return shape; + } + + bool IsStrideOne() const { + if (node_->attr().find("strides") != node_->attr().end()) { 
+ auto list = node_->attr().at("strides").list(); + return list.i(1) == 1 && list.i(2) == 1; + } + return false; + } + + bool IsValidPadding() const { + if (node_->attr().find("padding") != node_->attr().end()) { + auto padding = node_->attr().at("padding").s(); + return padding == "VALID"; + } + return false; + } + + // The logic inside this function is based on the internal implementation of + // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus + // needs to be updated accordingly if the internal implementation changes. + bool IsGemmUsed(const TensorShapeProto& filter_shape, + const TensorShapeProto& input_shape) const { + if (filter_shape.dim_size() == 4) { + if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 && + IsStrideOne()) { + return true; + } + } + if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) { + if (input_shape.dim(1).size() == filter_shape.dim(0).size() && + input_shape.dim(2).size() == filter_shape.dim(1).size() && + IsValidPadding()) { + return true; + } + } + return false; + } + + virtual bool IsGemmUsed() const { + auto filter_shape = GetShape(node_->input(1)); + auto input_shape = GetShape(node_->input(0)); + return IsGemmUsed(filter_shape, input_shape); + } + + bool no_gemm_; +}; + +class Conv2DBackpropFilterProcessor : public Conv2DProcessor { + public: + Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node, + NodeMap* node_map, bool no_gemm) + : Conv2DProcessor(graph, node, node_map, no_gemm) {} + + protected: + bool IsGemmUsed() const override { + auto filter_shape = GetShape(node_->name()); + auto input_shape = GetShape(node_->input(0)); + return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape); + } + std::vector GetInputPos() const override { std::vector input_pos = {0, 2}; return input_pos; @@ -370,17 +448,24 @@ class Conv2DBackpropFilterProcessor : public NodeProcessor { void UpdateAttrShape() override {} }; -class Conv2DBackpropInputProcessor : public NodeProcessor { 
+class Conv2DBackpropInputProcessor : public Conv2DProcessor { public: Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node, - NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + NodeMap* node_map, bool no_gemm) + : Conv2DProcessor(graph, node, node_map, no_gemm) {} protected: + bool IsGemmUsed() const override { + auto filter_shape = GetShape(node_->input(1)); + auto input_shape = GetShape(node_->name()); + return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape); + } + std::vector GetInputPos() const override { std::vector input_pos = {2}; return input_pos; } + Status CustomizedProcessing() override { NodeDef* node = node_map_->GetNode(node_->input(0)); return UpdateAttrValue(node); @@ -418,7 +503,7 @@ class AgnosticNodeProcessor : public NodeProcessor { protected: bool ShouldProcess() const override { - return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC(); + return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC(); } bool IsNodeAfterNCHWToNHWC() const { @@ -467,7 +552,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { - return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && + return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && (Is4DOperateWithND(4) || Is4DOperateWithScalar() || Is4DOperateWithVector()); } @@ -484,10 +569,10 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { auto input0 = node_map_->GetNode(node_->input(0)); auto input1 = node_map_->GetNode(node_->input(1)); if (input0 && input1) { - return (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) && + return (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) && ((n == 4) - ? (IsDimsFour(input1) || IsNodeNCHWToNHWC(input1->name())) - : IsDimsN(input1, n)); + ? 
(IsDimsFour(*input1) || IsNodeNCHWToNHWC(input1->name())) + : IsDimsN(*input1, n)); } return false; } @@ -571,7 +656,7 @@ class ConcatProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { - return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && + return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && IsAlongDimC(); } @@ -739,7 +824,7 @@ class SqueezeProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { - return IsDimsN(node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() && + return IsDimsN(*node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW(); } @@ -790,7 +875,7 @@ class SumProcessor : public AgnosticNodeProcessor { bool ShouldProcess() const override { auto input0 = node_map_->GetNode(node_->input(0)); return HasOutputs() && IsNodeAfterNCHWToNHWC() && - (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) && + (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) && IsAlongDimNHW(); } @@ -825,10 +910,21 @@ class SumProcessor : public AgnosticNodeProcessor { } }; +struct TuningConfig { + // If true, do not use the NHWC GEMM implementation. When filter size is + // one or filter size is equal to input image size, + // the NHWC implementation of Conv2D, Conv2DBackpropInput, and + // Conv2DBackpropFilter will use a specialized GEMM implementation, which is + // usually faster than the NCHW implementation. The downside is that this + // might result in more non-cancellable layout conversion nodes (implemented + // by the Transpose op). 
+ bool no_gemm; +}; + class DataLayoutOptimizer { public: - explicit DataLayoutOptimizer(GraphDef* graph) - : graph_(graph), node_map_(graph_) {} + explicit DataLayoutOptimizer(GraphDef* graph, TuningConfig config) + : graph_(graph), node_map_(graph_), config_(config) {} Status Optimize() { LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size(); @@ -908,12 +1004,15 @@ class DataLayoutOptimizer { } else if (node->op().compare("BiasAddGrad") == 0) { node_processor.reset( new BiasAddGradProcessor(graph_, node, &node_map_)); + } else if (node->op().compare("Conv2D") == 0) { + node_processor.reset( + new Conv2DProcessor(graph_, node, &node_map_, config_.no_gemm)); } else if (node->op().compare("Conv2DBackpropFilter") == 0) { - node_processor.reset( - new Conv2DBackpropFilterProcessor(graph_, node, &node_map_)); + node_processor.reset(new Conv2DBackpropFilterProcessor( + graph_, node, &node_map_, config_.no_gemm)); } else if (node->op().compare("Conv2DBackpropInput") == 0) { - node_processor.reset( - new Conv2DBackpropInputProcessor(graph_, node, &node_map_)); + node_processor.reset(new Conv2DBackpropInputProcessor( + graph_, node, &node_map_, config_.no_gemm)); } else if (node->op().compare("FusedBatchNormGrad") == 0) { node_processor.reset( new FusedBatchNormGradProcessor(graph_, node, &node_map_)); @@ -1025,17 +1124,46 @@ class DataLayoutOptimizer { GraphDef* graph_; NodeMap node_map_; + TuningConfig config_; }; +int GetNumTranspose(const GraphDef& graph) { + int number = 0; + for (const auto& node : graph.node()) { + if (IsTranspose(node)) { + number++; + } + } + LOG(INFO) << "Number of Transpose nodes: " << number; + return number; +} + Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { - if (GetNumAvailableGPUs() < 1) { + if (num_gpus_ == 0) { + num_gpus_ = GetNumAvailableGPUs(); + } + if (num_gpus_ < 1) { // LayoutOptimizer is currently only tuned for GPU. 
return Status::OK(); } + *output = item.graph; - DataLayoutOptimizer layout_optimizer(output); + TuningConfig config; + config.no_gemm = false; + DataLayoutOptimizer layout_optimizer(output, config); auto status = layout_optimizer.Optimize(); + + // This is based on an empirical observation that if more than 30 Transpose + // nodes are introduced, not using the GEMM implementation results in better + // performance. + if (status.ok() && GetNumTranspose(*output) > 30) { + *output = item.graph; + config.no_gemm = true; + DataLayoutOptimizer layout_optimizer(output, config); + status = layout_optimizer.Optimize(); + } + if (!status.ok()) { *output = item.graph; } diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h index 66dec17a35c..1bd6f9544b1 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.h +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.h @@ -29,11 +29,17 @@ class LayoutOptimizer : public GraphOptimizer { string name() const override { return "layout"; }; + // This is for testing only. + void set_num_gpus(int num_gpus) { num_gpus_ = num_gpus; }; + Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) override; void Feedback(Cluster* cluster, const GrapplerItem& item, const GraphDef& optimize_output, double result) override; + + private: + int num_gpus_ = 0; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc new file mode 100644 index 00000000000..be38ca1a69e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/layout_optimizer.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +void AddOutputShape(Node* node, const TensorShape& shape) { + std::vector output_shapes; + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + output_shapes.push_back(shape_proto); + node->AddAttr("_output_shapes", output_shapes); +} + +class LayoutOptimizerTest : public ::testing::Test { + protected: + Output SimpleConv(tensorflow::Scope* s, int input_size, int filter_size, + const string& padding) { + int batch_size = 128; + int input_height = input_size; + int input_width = input_size; + int input_depth = 3; + int filter_count = 2; + int stride = 1; + TensorShape input_shape( + {batch_size, input_height, input_width, input_depth}); + Tensor input_data(DT_FLOAT, input_shape); + test::FillIota(&input_data, 1.0f); + Output input = + ops::Const(s->WithOpName("Input"), Input::Initializer(input_data)); + AddOutputShape(input.node(), input_shape); + + TensorShape filter_shape( + {filter_size, filter_size, input_depth, filter_count}); + Tensor filter_data(DT_FLOAT, filter_shape); + test::FillIota(&filter_data, 
1.0f); + Output filter = + ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); + AddOutputShape(filter.node(), filter_shape); + + Output conv = ops::Conv2D(s->WithOpName("Conv2D"), input, filter, + {1, stride, stride, 1}, padding); + AddOutputShape(conv.node(), input_shape); + return conv; + } +}; + +TEST_F(LayoutOptimizerTest, FilterSizeIsOne) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 1, "SAME"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_FALSE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +TEST_F(LayoutOptimizerTest, FilterSizeNotOne) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 1, "SAME"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_FALSE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 2, "VALID"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_FALSE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) 
{ + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 2, "SAME"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_TRUE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 3, "VALID"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_TRUE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index abce506aba2..2776b95a3cd 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1327,6 +1327,14 @@ cc_library( ], ) +cc_library( + name = "lookup", + deps = [ + ":lookup_table_init_op", + ":lookup_table_op", + ], +) + DATA_FLOW_DEPS = [ ":bounds_check", ":concat_lib", @@ -1450,10 +1458,10 @@ LOOKUP_DEPS = [ ":initializable_lookup_table", ":lookup_util", "//tensorflow/core:core_cpu", - "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:lookup_ops_op_lib", ] tf_kernel_library( diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 1c7afcf8663..746fe63e2a0 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc 
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -19,9 +19,6 @@ limitations under the License. #include "tensorflow/core/kernels/crop_and_resize_op.h" -#include -#include - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -29,13 +26,10 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -43,67 +37,41 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -using Callback = std::function; -namespace { - -static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, - const Tensor& box_index, - int* num_boxes) { - if (boxes.NumElements() == 0 && box_index.NumElements() == 0) { +static inline void ParseAndCheckBoxSizes(OpKernelContext* context, + const Tensor& boxes, + const Tensor& box_ind, + int* num_boxes) { + if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) { *num_boxes = 0; - return Status::OK(); + return; } // The shape of 'boxes' is [num_boxes, 4]. - if (boxes.dims() != 2) { - return errors::InvalidArgument("boxes must be 2-D", - boxes.shape().DebugString()); - } + OP_REQUIRES(context, boxes.dims() == 2, + errors::InvalidArgument("boxes must be 2-D", + boxes.shape().DebugString())); *num_boxes = boxes.dim_size(0); - if (boxes.dim_size(1) != 4) { - return errors::InvalidArgument("boxes must have 4 columns"); - } - // The shape of 'box_index' is [num_boxes]. 
- if (box_index.dims() != 1) { - return errors::InvalidArgument("box_index must be 1-D", - box_index.shape().DebugString()); - } - if (box_index.dim_size(0) != *num_boxes) { - return errors::InvalidArgument("box_index has incompatible shape"); - } - return Status::OK(); + OP_REQUIRES(context, boxes.dim_size(1) == 4, + errors::InvalidArgument("boxes must have 4 columns")); + + // The shape of 'box_ind' is [num_boxes]. + OP_REQUIRES(context, box_ind.dims() == 1, + errors::InvalidArgument("box_ind must be 1-D", + box_ind.shape().DebugString())); + OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes, + errors::InvalidArgument("box_ind has incompatible shape")); } -// Conditionally calls the compute callback if all values in box_index are in -// [0, batch_size) then calls done. +// Verifies that all values in box_ind are in [0, batch). template -inline void RunIfBoxIndexIsValid( - OpKernelContext* context, typename TTypes::ConstTensor box_index, - int batch_size, Callback compute, Callback done); - -// Specialization of CheckValidBoxIndex for a CPUDevice. 
-template <> -inline void RunIfBoxIndexIsValid( - OpKernelContext* context, typename TTypes::ConstTensor box_index, - int batch_size, Callback compute, Callback done) { - const int num_boxes = box_index.dimension(0); - for (int b = 0; b < num_boxes; ++b) { - OP_REQUIRES_ASYNC( - context, FastBoundsCheck(box_index(b), batch_size), - errors::OutOfRange("box_index has values outside [0, batch_size)"), - done); - } - compute(); - done(); -} - -} // namespace +inline void CheckValidBoxInd( + OpKernelContext* context, + typename TTypes::ConstTensor box_ind_data, int batch); template -class CropAndResizeOp : public AsyncOpKernel { +class CropAndResizeOp : public OpKernel { public: - explicit CropAndResizeOp(OpKernelConstruction* context) - : AsyncOpKernel(context) { + explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", @@ -112,77 +80,69 @@ class CropAndResizeOp : public AsyncOpKernel { &extrapolation_value_)); } - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { - // The shape of 'image' is [batch_size, image_height, image_width, - // channels]. + void Compute(OpKernelContext* context) override { + // The shape of 'image' is [batch, image_height, image_width, channels]. const Tensor& image = context->input(0); - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(1); - // The shape of 'box_index' is [num_boxes]. - const Tensor& box_index = context->input(2); - // The shape of 'crop_size' is [2]. - const Tensor& crop_size = context->input(3); + OP_REQUIRES(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString())); - // Validate inputs dimensions. 
- OP_REQUIRES_ASYNC(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString()), - done); - const int batch_size = image.dim_size(0); + const int batch = image.dim_size(0); const int image_height = image.dim_size(1); const int image_width = image.dim_size(2); const int depth = image.dim_size(3); - OP_REQUIRES_ASYNC( - context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive"), done); + OP_REQUIRES(context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive")); + + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(1); + + // The shape of 'box_ind' is [num_boxes]. + const Tensor& box_ind = context->input(2); + int num_boxes = 0; - OP_REQUIRES_OK_ASYNC( - context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - OP_REQUIRES_ASYNC(context, crop_size.dims() == 1, - errors::InvalidArgument("crop_size must be 1-D", - crop_size.shape().DebugString()), - done); - OP_REQUIRES_ASYNC( - context, crop_size.dim_size(0) == 2, - errors::InvalidArgument("crop_size must have two elements", - crop_size.shape().DebugString()), - done); + // The shape of 'crop_size' is [2]. + const Tensor& crop_size = context->input(3); + + OP_REQUIRES(context, crop_size.dims() == 1, + errors::InvalidArgument("crop_size must be 1-D", + crop_size.shape().DebugString())); + OP_REQUIRES(context, crop_size.dim_size(0) == 2, + errors::InvalidArgument("crop_size must have two elements", + crop_size.shape().DebugString())); - // Copy and validate crop sizes. 
auto crop_size_vec = crop_size.vec(); const int crop_height = internal::SubtleMustCopy(crop_size_vec(0)); const int crop_width = internal::SubtleMustCopy(crop_size_vec(1)); - OP_REQUIRES_ASYNC( - context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("crop dimensions must be positive"), done); + OP_REQUIRES(context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("crop dimensions must be positive")); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( + OP_REQUIRES_OK( context, context->allocate_output( 0, TensorShape({num_boxes, crop_height, crop_width, depth}), - &output), - done); + &output)); - auto compute_callback = [this, context, output]() { - const Tensor& image = context->input(0); - const Tensor& boxes = context->input(1); - const Tensor& box_index = context->input(2); - const bool status = functor::CropAndResize()( - context->eigen_device(), image.tensor(), - boxes.tensor(), box_index.tensor(), - extrapolation_value_, output->tensor()); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeKernel.")); - } - }; + typename TTypes::ConstTensor image_data = image.tensor(); + typename TTypes::ConstTensor boxes_data = + boxes.tensor(); + typename TTypes::ConstTensor box_ind_data = + box_ind.tensor(); + typename TTypes::Tensor crops_data = output->tensor(); - RunIfBoxIndexIsValid(context, box_index.tensor(), - batch_size, std::move(compute_callback), - std::move(done)); + CheckValidBoxInd(context, box_ind_data, batch); + + bool status = functor::CropAndResize()( + context->eigen_device(), image_data, boxes_data, box_ind_data, + extrapolation_value_, crops_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeKernel.")); + } } private: @@ -195,10 +155,10 @@ template struct CropAndResize { bool operator()(const CPUDevice& d, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_index, + 
typename TTypes::ConstTensor box_ind, float extrapolation_value, typename TTypes::Tensor crops) { - const int batch_size = image.dimension(0); + const int batch = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -213,8 +173,8 @@ struct CropAndResize { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_index(b); - if (!FastBoundsCheck(b_in, batch_size)) { + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { continue; } @@ -275,94 +235,89 @@ struct CropAndResize { return true; } }; - } // namespace functor template -class CropAndResizeGradImageOp : public AsyncOpKernel { +class CropAndResizeGradImageOp : public OpKernel { public: explicit CropAndResizeGradImageOp(OpKernelConstruction* context) - : AsyncOpKernel(context) { + : OpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + void Compute(OpKernelContext* context) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(1); - // The shape of 'box_index' is [num_boxes]. - const Tensor& box_index = context->input(2); - // The shape of 'image_size' is [4]. - const Tensor& image_size = context->input(3); - // Validate input shapes. 
- OP_REQUIRES_ASYNC(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString()), - done); + OP_REQUIRES(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString())); const int crop_height = grads.dim_size(1); const int crop_width = grads.dim_size(2); - OP_REQUIRES_ASYNC( - context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive"), done); - int num_boxes = 0; - OP_REQUIRES_OK_ASYNC( - context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); - OP_REQUIRES_ASYNC( - context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape"), - done); + OP_REQUIRES(context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive")); + + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(1); + + // The shape of 'box_ind' is [num_boxes]. + const Tensor& box_ind = context->input(2); + + int num_boxes = 0; + ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); + + OP_REQUIRES( + context, grads.dim_size(0) == num_boxes, + errors::InvalidArgument("boxes and grads have incompatible shape")); + + // The shape of 'image_size' is [4]. 
+ const Tensor& image_size = context->input(3); + OP_REQUIRES(context, image_size.dims() == 1, + errors::InvalidArgument("image_size must be 1-D", + image_size.shape().DebugString())); + OP_REQUIRES(context, image_size.dim_size(0) == 4, + errors::InvalidArgument("image_size must have 4 elements", + image_size.shape().DebugString())); - OP_REQUIRES_ASYNC(context, image_size.dims() == 1, - errors::InvalidArgument("image_size must be 1-D", - image_size.shape().DebugString()), - done); - OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4, - errors::InvalidArgument("image_size must have 4 elements", - image_size.shape().DebugString()), - done); auto image_size_vec = image_size.vec(); - const int batch_size = internal::SubtleMustCopy(image_size_vec(0)); + const int batch = internal::SubtleMustCopy(image_size_vec(0)); const int image_height = internal::SubtleMustCopy(image_size_vec(1)); const int image_width = internal::SubtleMustCopy(image_size_vec(2)); const int depth = internal::SubtleMustCopy(image_size_vec(3)); - OP_REQUIRES_ASYNC( - context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive"), done); - OP_REQUIRES_ASYNC( + + OP_REQUIRES(context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive")); + OP_REQUIRES( context, grads.dim_size(3) == depth, - errors::InvalidArgument("image_size and grads are incompatible"), done); + errors::InvalidArgument("image_size and grads are incompatible")); // Allocate output tensor. 
Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_output( - 0, TensorShape({batch_size, image_height, image_width, depth}), - &output), - done); + OP_REQUIRES_OK( + context, context->allocate_output( + 0, TensorShape({batch, image_height, image_width, depth}), + &output)); - auto compute_callback = [context, output]() { - const Tensor& grads = context->input(0); - const Tensor& boxes = context->input(1); - const Tensor& box_index = context->input(2); - const bool status = functor::CropAndResizeBackpropImage()( - context->eigen_device(), grads.tensor(), - boxes.tensor(), box_index.tensor(), - output->tensor()); - if (!status) { - context->SetStatus(errors::Internal( - "Failed launch CropAndResizeBackpropImage kernel.")); - } - }; + typename TTypes::ConstTensor grads_data = + grads.tensor(); + typename TTypes::ConstTensor boxes_data = + boxes.tensor(); + typename TTypes::ConstTensor box_ind_data = + box_ind.tensor(); + typename TTypes::Tensor output_data = output->tensor(); - RunIfBoxIndexIsValid(context, box_index.tensor(), - batch_size, std::move(compute_callback), - std::move(done)); + CheckValidBoxInd(context, box_ind_data, batch); + + bool status = functor::CropAndResizeBackpropImage()( + context->eigen_device(), grads_data, boxes_data, box_ind_data, + output_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeBackpropImageKernel.")); + } } }; @@ -373,9 +328,9 @@ struct CropAndResizeBackpropImage { bool operator()(const CPUDevice& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_index, + typename TTypes::ConstTensor box_ind, typename TTypes::Tensor grads_image) { - const int batch_size = grads_image.dimension(0); + const int batch = grads_image.dimension(0); const int image_height = grads_image.dimension(1); const int image_width = grads_image.dimension(2); @@ -392,8 +347,8 @@ struct CropAndResizeBackpropImage { const float y2 = 
boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_index(b); - if (!FastBoundsCheck(b_in, batch_size)) { + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { continue; } @@ -444,90 +399,83 @@ struct CropAndResizeBackpropImage { return true; } }; - } // namespace functor template -class CropAndResizeGradBoxesOp : public AsyncOpKernel { +class CropAndResizeGradBoxesOp : public OpKernel { public: explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context) - : AsyncOpKernel(context) { + : OpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + void Compute(OpKernelContext* context) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(2); - // The shape of 'box_index' is [num_boxes]. - const Tensor& box_index = context->input(3); - // The shape of 'image' is [batch_size, image_height, image_width, depth]. - const Tensor& image = context->input(1); - // Validate input shapes. 
- OP_REQUIRES_ASYNC(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString()), - done); + OP_REQUIRES(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString())); + const int crop_height = grads.dim_size(1); const int crop_width = grads.dim_size(2); const int depth = grads.dim_size(3); - OP_REQUIRES_ASYNC( - context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive"), done); + OP_REQUIRES(context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive")); - OP_REQUIRES_ASYNC(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString()), - done); - const int batch_size = image.dim_size(0); + // The shape of 'image' is [batch, image_height, image_width, depth]. + const Tensor& image = context->input(1); + OP_REQUIRES(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString())); + + const int batch = image.dim_size(0); const int image_height = image.dim_size(1); const int image_width = image.dim_size(2); - OP_REQUIRES_ASYNC( - context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive"), done); - OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth, - errors::InvalidArgument("image, grads depth differ"), - done); + OP_REQUIRES(context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive")); + OP_REQUIRES(context, image.dim_size(3) == depth, + errors::InvalidArgument("image, grads depth differ")); + + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(2); + + // The shape of 'box_ind' is [num_boxes]. 
+ const Tensor& box_ind = context->input(3); int num_boxes = 0; - OP_REQUIRES_OK_ASYNC( - context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - OP_REQUIRES_ASYNC( + OP_REQUIRES( context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape"), - done); + errors::InvalidArgument("boxes and grads have incompatible shape")); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_output(0, TensorShape({num_boxes, 4}), &output), - done); + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({num_boxes, 4}), &output)); - auto compute_callback = [context, output]() { - const Tensor& grads = context->input(0); - const Tensor& image = context->input(1); - const Tensor& boxes = context->input(2); - const Tensor& box_index = context->input(3); - const bool status = functor::CropAndResizeBackpropBoxes()( - context->eigen_device(), grads.tensor(), - image.tensor(), boxes.tensor(), - box_index.tensor(), output->tensor()); - if (!status) { - context->SetStatus(errors::Internal( - "Failed launch CropAndResizeBackpropBoxes kernel.")); - } - }; + typename TTypes::ConstTensor grads_data = + grads.tensor(); + typename TTypes::ConstTensor image_data = image.tensor(); + typename TTypes::ConstTensor boxes_data = + boxes.tensor(); + typename TTypes::ConstTensor box_ind_data = + box_ind.tensor(); + typename TTypes::Tensor output_data = output->tensor(); - RunIfBoxIndexIsValid(context, box_index.tensor(), - batch_size, std::move(compute_callback), - std::move(done)); + CheckValidBoxInd(context, box_ind_data, batch); + + bool status = functor::CropAndResizeBackpropBoxes()( + context->eigen_device(), grads_data, image_data, boxes_data, + box_ind_data, output_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel.")); + } } }; @@ -539,9 +487,9 @@ 
struct CropAndResizeBackpropBoxes { typename TTypes::ConstTensor grads, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_index, + typename TTypes::ConstTensor box_ind, typename TTypes::Tensor grads_boxes) { - const int batch_size = image.dimension(0); + const int batch = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -558,8 +506,8 @@ struct CropAndResizeBackpropBoxes { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_index(b); - if (!FastBoundsCheck(b_in, batch_size)) { + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { continue; } @@ -641,19 +589,30 @@ struct CropAndResizeBackpropBoxes { return true; } }; - } // namespace functor -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("crop_size"), \ - CropAndResizeOp); \ - \ - REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +// Specialization of CheckValidBoxInd for a CPUDevice. 
+template <> +inline void CheckValidBoxInd( + OpKernelContext* context, typename TTypes::ConstTensor box_ind, + int batch) { + const int num_boxes = box_ind.dimension(0); + for (int b = 0; b < num_boxes; ++b) { + OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch, + errors::OutOfRange("box_ind has values outside [0, batch)")); + } +} + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("crop_size"), \ + CropAndResizeOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ CropAndResizeGradBoxesOp); TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); @@ -675,86 +634,50 @@ TF_CALL_double(REGISTER_KERNEL); #if GOOGLE_CUDA -// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU. +// Forward declaration of the CheckValidBoxIndHelper specialization for GPU. namespace functor { template <> -void CheckValidBoxIndexHelper::operator()( - const GPUDevice& d, typename TTypes::ConstTensor box_index, - int batch_size, typename TTypes::Tensor isvalid); -extern template struct CheckValidBoxIndexHelper; +void CheckValidBoxIndHelper::operator()( + const GPUDevice& d, typename TTypes::ConstTensor box_ind, + int batch, typename TTypes::Tensor isvalid); +extern template struct CheckValidBoxIndHelper; } // namespace functor -namespace { - -// Specialization of CheckValidBoxIndex for a GPUDevice. +// Specialization of CheckValidBoxInd for a GPUDevice. 
template <> -inline void RunIfBoxIndexIsValid( - OpKernelContext* context, typename TTypes::ConstTensor box_index, - int batch_size, Callback compute, Callback done) { - const int num_boxes = box_index.dimension(0); +inline void CheckValidBoxInd( + OpKernelContext* context, typename TTypes::ConstTensor box_ind, + int batch) { + const int num_boxes = box_ind.dimension(0); if (num_boxes == 0) { - compute(); - done(); return; } + Tensor isvalid_tensor; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({}), &isvalid_tensor)); - Tensor isvalid_dev_tensor; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_temp(DataTypeToEnum::value, TensorShape({}), - &isvalid_dev_tensor), - done); - typename TTypes::Tensor isvalid_dev = - isvalid_dev_tensor.tensor(); + typename TTypes::Tensor isvalid = isvalid_tensor.tensor(); - // Run the actual box check on the device. - functor::CheckValidBoxIndexHelper()( - context->eigen_device(), box_index, batch_size, isvalid_dev); + functor::CheckValidBoxIndHelper()( + context->eigen_device(), box_ind, batch, isvalid); - // Copy the result back to the host. auto* stream = context->op_device_context()->stream(); - OP_REQUIRES_ASYNC(context, stream, - errors::Internal("No GPU stream available."), done); - Tensor isvalid_host_tensor; - // Use pinned host memory on the host to avoid unnecessary - // synchronization. 
- AllocatorAttributes alloc_attr; - alloc_attr.set_on_host(true); - alloc_attr.set_gpu_compatible(true); - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_temp(DataTypeToEnum::value, TensorShape({}), - &isvalid_host_tensor, alloc_attr), - done); - typename TTypes::Tensor isvalid_host = - isvalid_host_tensor.tensor(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(), - sizeof(bool)); - const bool status = stream - ->ThenMemcpy(isvalid_host.data() /* destination */, - wrapped /* source */, sizeof(bool)) - .ok(); - OP_REQUIRES_ASYNC( - context, status, - errors::Internal("Failed to launch copy of isvalid from device to host."), - done); + bool isvalid_host = false; + perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(), + sizeof(bool)); + stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool)); + stream->BlockHostUntilDone(); - auto wrapped_callback = [context, isvalid_host, compute, done]() { - OP_REQUIRES_ASYNC( - context, isvalid_host(), - errors::OutOfRange("box_index has values outside [0, batch_size)"), - done); - compute(); - done(); - }; + OP_REQUIRES(context, stream->ok(), + errors::Internal("cudaMemcpy from device to host failed")); - context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( - stream, wrapped_callback); + OP_REQUIRES(context, isvalid_host, + errors::OutOfRange("box_ind has values outside [0, batch)")); } -} // namespace - #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ .Device(DEVICE_GPU) \ diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h index 460dbad22b4..22df1bdd56b 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.h +++ b/tensorflow/core/kernels/crop_and_resize_op.h @@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes { }; template -struct CheckValidBoxIndexHelper { - // Checks if all values in box_index are in [0, 
batch). +struct CheckValidBoxIndHelper { + // Checks if all values in box_ind are in [0, batch). void operator()(const Device& d, - typename TTypes::ConstTensor box_index, int batch, + typename TTypes::ConstTensor box_ind, int batch, typename TTypes::Tensor isvalid) { - isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all(); + isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all(); } }; diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc index c1235fda892..254475db465 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc @@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); #undef DEFINE_GPU_SPECS -template struct CheckValidBoxIndexHelper; +template struct CheckValidBoxIndHelper; } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc index d6139dae966..3a7f180598e 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_test.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc @@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( - StringPiece(s.ToString()).contains("box_index has incompatible shape")) + StringPiece(s.ToString()).contains("box_ind has incompatible shape")) << s; } @@ -264,10 +264,8 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(StringPiece(s.ToString()) - .contains("box_index has values outside [0, batch_size)")) + .contains("box_ind has values outside [0, batch)")) << s; } -// TODO(zhengxq, rmlarsen): Add a benchmark. 
- } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc index b5093d59fc0..48f38872e25 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc @@ -47,16 +47,26 @@ class SparseTensorDenseAddOp : public OpKernel { "Input a_indices should be a matrix but received shape: ", a_indices_t->shape().DebugString())); OP_REQUIRES( - ctx, TensorShapeUtils::IsVector(a_values_t->shape()) && - TensorShapeUtils::IsVector(a_shape_t->shape()), + ctx, + TensorShapeUtils::IsVector(a_values_t->shape()) && + TensorShapeUtils::IsVector(a_shape_t->shape()), errors::InvalidArgument("Inputs a_values and a_shape should be vectors " "but received shapes: ", a_values_t->shape().DebugString(), " and ", a_shape_t->shape().DebugString())); - OP_REQUIRES(ctx, a_shape_t->NumElements() == b->dims(), - errors::InvalidArgument( - "Two operands have different dimensions; received: ", - a_shape_t->NumElements(), " and ", b->dims())); + OP_REQUIRES( + ctx, a_shape_t->NumElements() == b->dims(), + errors::InvalidArgument("Two operands have different ranks; received: ", + a_shape_t->NumElements(), " and ", b->dims())); + const auto a_shape_flat = a_shape_t->flat(); + for (int i = 0; i < b->dims(); ++i) { + OP_REQUIRES( + ctx, a_shape_flat(i) == b->dim_size(i), + errors::InvalidArgument( + "Dimension ", i, + " does not equal (no broadcasting is supported): sparse side ", + a_shape_flat(i), " vs dense side ", b->dim_size(i))); + } Tensor *out_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, b->shape(), &out_t)); @@ -82,8 +92,9 @@ class SparseTensorDenseAddOp : public OpKernel { NDIMS_CASE(4); NDIMS_CASE(5); default: - OP_REQUIRES(ctx, false, errors::InvalidArgument( - "Only tensors with ranks between 1 and 5 " + OP_REQUIRES( + ctx, false, + errors::InvalidArgument("Only tensors with ranks between 1 and 5 " "are currently supported. 
Tensor rank: ", ndims)); #undef NDIMS_CASE diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index 30026f222a6..30c57ef287f 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -65,7 +65,8 @@ class SparseTensorDenseMatMulOp : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()), errors::InvalidArgument("Tensor 'a_indices' is not a matrix")); - OP_REQUIRES(ctx, a_indices->shape().dim_size(0) == a_values->NumElements(), + const int64 nnz = a_indices->shape().dim_size(0); + OP_REQUIRES(ctx, nnz == a_values->NumElements(), errors::InvalidArgument("Number of rows of a_indices does not " "match number of entries in a_values")); @@ -89,8 +90,28 @@ class SparseTensorDenseMatMulOp : public OpKernel { inner_left, " vs. ", inner_right, ". Did you forget a transpose? " "Dimensions of A: [", - a_shape_t(0), ", ", a_shape_t(1), "). Dimensions of B: ", - b->shape().DebugString())); + a_shape_t(0), ", ", a_shape_t(1), + "). Dimensions of B: ", b->shape().DebugString())); + + if (std::is_same::value) { + // The GPU implementation is optimized to use 32 bit indexing, so + // give a friendly error to the programmer early on if they + // exceed. 
+ const int int32max = std::numeric_limits::max(); + OP_REQUIRES( + ctx, + (FastBoundsCheck(inner_left, int32max) && + FastBoundsCheck(inner_right, int32max) && + FastBoundsCheck(outer_left, int32max) && + FastBoundsCheck(outer_right, int32max) && + FastBoundsCheck(b->NumElements(), int32max) && + FastBoundsCheck(outer_left * outer_right, int32max) && + FastBoundsCheck(a_values->NumElements(), int32max)), + errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs")); + OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max), + errors::InvalidArgument( + "Cannot use GPU when output.shape[1] * nnz(a) > 2^31")); + } TensorShape out_shape({outer_left, outer_right}); Tensor* out = nullptr; @@ -111,41 +132,13 @@ class SparseTensorDenseMatMulOp : public OpKernel { return; } - Tensor scratch; - - if (std::is_same::value) { - // The GPU implementation is optimized to use 32 bit indexing, so - // give a friendly error to the programmer early on if they exceed. - OP_REQUIRES( - ctx, - FastBoundsCheck(inner_left, std::numeric_limits::max()) && - FastBoundsCheck(inner_right, std::numeric_limits::max()) && - FastBoundsCheck(outer_left, std::numeric_limits::max()) && - FastBoundsCheck(outer_right, std::numeric_limits::max()) && - FastBoundsCheck(b->NumElements(), - std::numeric_limits::max()) && - FastBoundsCheck(out->NumElements(), - std::numeric_limits::max()) && - FastBoundsCheck(a_values->NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs")); - const int nnz = static_cast(a_values->NumElements()); - // Need nnz length vec scratch space on the GPU. - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, - TensorShape({nnz}), &scratch)); - } else { - // We don't need scratch space on the CPU. 
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, - TensorShape({0}), &scratch)); - } - #define MAYBE_ADJOINT(ADJ_A, ADJ_B) \ if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \ Status functor_status = functor::SparseTensorDenseMatMulFunctor< \ Device, T, Tindices, ADJ_A, \ ADJ_B>::Compute(ctx->eigen_device(), out->matrix(), \ a_indices->matrix(), a_values->vec(), \ - b->matrix(), scratch.vec()); \ + b->matrix()); \ OP_REQUIRES_OK(ctx, functor_status); \ } @@ -189,10 +182,9 @@ namespace functor { Status SparseTensorDenseMatMulFunctor< \ GPUDevice, T, Tindices, ADJ_A, \ ADJ_B>::Compute(const GPUDevice& d, typename TTypes::Matrix out, \ - typename TTypes::ConstMatrix a_indices, \ + TTypes::ConstMatrix a_indices, \ typename TTypes::ConstVec a_values, \ - typename TTypes::ConstMatrix b, \ - typename TTypes::Vec scratch); \ + typename TTypes::ConstMatrix b); \ extern template struct SparseTensorDenseMatMulFunctor< \ GPUDevice, T, Tindices, ADJ_A, ADJ_B>; @@ -255,8 +247,7 @@ struct SparseTensorDenseMatMulFunctor { static Status Compute(const CPUDevice& d, typename TTypes::Matrix out, typename TTypes::ConstMatrix a_indices, typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, - typename TTypes::Vec scratch) { + typename TTypes::ConstMatrix b) { const std::size_t nnz = a_values.size(); const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1)); const std::size_t lhs_right = (ADJ_B ? 
b.dimension(1) : b.dimension(0)); diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h index e707743f782..da131904949 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h @@ -28,11 +28,10 @@ namespace functor { template struct SparseTensorDenseMatMulFunctor { - static EIGEN_ALWAYS_INLINE Status - Compute(const Device& d, typename TTypes::Matrix out, - typename TTypes::ConstMatrix a_indices, - typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, typename TTypes::Vec scratch); + static EIGEN_ALWAYS_INLINE Status Compute( + const Device& d, typename TTypes::Matrix out, + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, typename TTypes::ConstMatrix b); }; template diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc index 7266e0cf812..e261e42e0d3 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc @@ -20,71 +20,45 @@ limitations under the License. #include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -namespace generator { - template -class SparseTensorDenseMatMulGPUGenerator { - public: - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator( - typename TTypes::Tensor32Bit out, - typename TTypes::Tensor32Bit a_indices, - typename TTypes::Tensor32Bit a_values, - typename TTypes::Tensor32Bit b) - : out_(out), - lhs_index_a_(ADJ_A ? 1 : 0), - rhs_index_a_(ADJ_A ? 0 : 1), - a_indices_(a_indices), - a_values_(a_values), - lhs_right_size(ADJ_B ? 
b.dimension(1) : b.dimension(0)), - maybe_adjoint_b_( - functor::MaybeAdjoint::Tensor32Bit, - ADJ_B>(b)) {} - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T - operator()(const Eigen::array& j_and_ix) const { -#ifdef __CUDA_ARCH__ - const int j = j_and_ix[0]; - const int ix = j_and_ix[1]; - int m = a_indices_(ix, lhs_index_a_); - int k = a_indices_(ix, rhs_index_a_); - assert(k < lhs_right_size); - assert(m < out_.dimension(0)); - // If asserts are disabled, the caller is violating the sparse - // tensor index contract, and so we return invalid results. - // Force returning NaNs to try to signal that something is amiss. - T b_value; - if (k >= lhs_right_size || m >= out_.dimension(0)) { - m = 0; - k = 0; - b_value = std::numeric_limits::quiet_NaN(); - } else { - b_value = maybe_adjoint_b_(k, j); +__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows, + int b_cols, int p, + const Tindices* a_indices, + const T* a_values, const T* b, + T* out) { + // out_{ij} = sum_k {a_ik b_kj} + // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk} + const int n = (ADJ_B) ? b_cols : b_rows; + CUDA_1D_KERNEL_LOOP(index, nnz * p) { + const int a_ix = index / p; + const int j = index % p; + const int i = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 1 : 0)); + const int k = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 0 : 1)); + if (!FastBoundsCheck(i, m)) { + continue; // Nowhere to signal an error :( } - atomicAdd(&out_(m, j), a_values_(ix) * b_value); -#else - assert(false && "This should only be run on the device"); -#endif - // Return something - return T(0); + // out[i, j] + T* out_location = out + i * p + j; + if (!FastBoundsCheck(k, n)) { + CudaAtomicAdd(out_location, std::numeric_limits::quiet_NaN()); + continue; + } + + // a_value == (ADJ_A) ? a[k, i] : a[i, k] + const T a_value = ldg(a_values + a_ix); + + // b_value == (ADJ_B) ? b[j, k] : b[k, j] + const T b_value = ldg(b + ((ADJ_B) ? 
j * b_cols + k : k * b_cols + j)); + CudaAtomicAdd(out_location, a_value * b_value); } - - private: - mutable typename TTypes::Tensor32Bit out_; - const int lhs_index_a_; - const int rhs_index_a_; - typename TTypes::Tensor32Bit a_indices_; - typename TTypes::Tensor32Bit a_values_; - const int lhs_right_size; - functor::MaybeAdjoint::Tensor32Bit, ADJ_B> - maybe_adjoint_b_; -}; - -} // namespace generator +} namespace functor { @@ -94,51 +68,23 @@ struct SparseTensorDenseMatMulFunctor { Compute(const GPUDevice& d, typename TTypes::Matrix out, typename TTypes::ConstMatrix a_indices, typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, typename TTypes::Vec scratch) { - generator::SparseTensorDenseMatMulGPUGenerator - sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices), - To32Bit(a_values), To32Bit(b)); - To32Bit(out).device(d) = To32Bit(out).constant(T(0)); + typename TTypes::ConstMatrix b) { + out.device(d) = out.constant(T(0)); int nnz = a_values.size(); - int n = (ADJ_B) ? b.dimension(0) : b.dimension(1); + // out = A * B, A is [m x n] and B is [n x p], out is [m x p] + int m = out.dimension(0); + int p = out.dimension(1); + int b_rows = b.dimension(0); + int b_cols = b.dimension(1); -#if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::Tensor::Dimensions matrix_1_by_nnz{{ 1, nnz }}; - Eigen::array n_by_1{{ n, 1 }}; - Eigen::array reduce_on_rows{{ 0 }}; -#else - Eigen::IndexList, int> matrix_1_by_nnz; - matrix_1_by_nnz.set(1, nnz); - Eigen::IndexList > n_by_1; - n_by_1.set(0, n); - Eigen::IndexList > reduce_on_rows; -#endif + // TODO(ebrevdo): Should this be alpha * nnz instead of + // out.size()? Perhaps p * nnz ? + CudaLaunchConfig config = GetCudaLaunchConfig(p * nnz, d); - // How this works: the generator iterates over (j, ix) where j - // iterates from 0 .. n - 1 and ix iterates from - // 0 .. nnz - 1. 
A side effect of the generator is to accumulate - // the products of values in A and B into the appropriate location - // in the dense matrix out. In order to run the iteration, - // we take a smaller variable and broadcast to a size (n, nnz). - // This is the scratch variable. In order to enforce execution, - // we have to perform assignment back into scratch (taking the sum). - // We don't care what gets assigned to scratch - only the side effect - // of the execution in the generator. - // - // Note it's not sufficient that scratch be a scalar, and to - // broadcast it to a matrix. Eigen splits the computation not - // based on the largest intermediate shape (the size of the - // broadcast of scratch) but based on the output shape. So - // scratch needs to be a vector at least. - // - // Note also that only float type is supported because the - // atomicAdd operation is only supported for floats in hardware. - To32Bit(scratch).device(d) = - To32Bit(scratch) - .reshape(matrix_1_by_nnz) - .broadcast(n_by_1) - .generate(sparse_tensor_dense_matmul_generator) - .sum(reduce_on_rows); + SparseTensorDenseMatMulKernel + <<>>( + nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(), + b.data(), out.data()); return Status::OK(); } diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc index f5d4fcec84c..d50e2060acf 100644 --- a/tensorflow/core/kernels/unique_op.cc +++ b/tensorflow/core/kernels/unique_op.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include "tensorflow/core/framework/op_kernel.h" @@ -21,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { @@ -50,8 +50,7 @@ class UniqueOp : public OpKernel { {0}, 1, input.shape(), &idx)); auto idx_vec = idx->template vec(); - std::unordered_map uniq; - uniq.reserve(2 * N); + gtl::FlatMap uniq(N); for (int64 i = 0, j = 0; i < N; ++i) { auto it = uniq.insert(std::make_pair(Tin(i), j)); idx_vec(i) = it.first->second; diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h index 8c173a4ba30..25b17b26c8d 100644 --- a/tensorflow/core/kernels/variable_ops.h +++ b/tensorflow/core/kernels/variable_ops.h @@ -76,6 +76,18 @@ class VariableOp : public OpKernel { // As long as the resource manager hasn't been cleared the ref we return // here is valid because it owns a ref on var. ctx->set_output_ref(0, var->mu(), var->tensor()); + if (ctx->track_allocations() && var->tensor()->IsInitialized()) { + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + if (ctx->allocate_on_host(attr)) { + ctx->record_host_persistent_memory_allocation( + var->tensor()->AllocatedBytes()); + } else { + ctx->record_device_persistent_memory_allocation( + var->tensor()->AllocatedBytes()); + } + } var->Unref(); } diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index f35a1bb6489..032ede6459c 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1876,604 +1876,6 @@ size: The number of incomplete elements (i.e. 
those with some of their value // -------------------------------------------------------------------------- -REGISTER_OP("LookupTableFind") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("default_value: Tout") - .Output("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // Default value must be scalar or vector. - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); - c->set_output(0, c->UnknownShape()); - return Status::OK(); - }) - .Doc(R"doc( -Looks up keys in a table, outputs the corresponding values. - -The tensor `keys` must of the same type as the keys of the table. -The output `values` is of the type of the table values. - -The scalar `default_value` is the value output for keys not present in the -table. It must also be of the same type as the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Same shape as `keys`. Values found in the table, or `default_values` - for missing keys. -)doc"); - -REGISTER_OP("LookupTableFindV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("default_value: Tout") - .Output("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // Default value must be scalar or vector. - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); - c->set_output(0, c->UnknownShape()); - return Status::OK(); - }) - .Doc(R"doc( -Looks up keys in a table, outputs the corresponding values. - -The tensor `keys` must of the same type as the keys of the table. -The output `values` is of the type of the table values. 
- -The scalar `default_value` is the value output for keys not present in the -table. It must also be of the same type as the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Same shape as `keys`. Values found in the table, or `default_values` - for missing keys. -)doc"); - -REGISTER_OP("LookupTableInsert") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // TODO(ebrevdo): Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Updates the table to associates keys with values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("LookupTableInsertV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // TODO: Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Updates the table to associates keys with values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. 
-)doc"); - -REGISTER_OP("LookupTableSize") - .Input("table_handle: Ref(string)") - .Output("size: int64") - .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) - .Doc(R"doc( -Computes the number of elements in the given table. - -table_handle: Handle to the table. -size: Scalar that contains number of elements in the table. -)doc"); - -REGISTER_OP("LookupTableSizeV2") - .Input("table_handle: resource") - .Output("size: int64") - .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs) - .Doc(R"doc( -Computes the number of elements in the given table. - -table_handle: Handle to the table. -size: Scalar that contains number of elements in the table. -)doc"); - -REGISTER_OP("LookupTableExport") - .Input("table_handle: Ref(string)") - .Output("keys: Tkeys") - .Output("values: Tvalues") - .Attr("Tkeys: type") - .Attr("Tvalues: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - ShapeHandle values = c->UnknownShape(); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); - ShapeHandle keys = c->Vector(c->Dim(values, 0)); - c->set_output(0, keys); - c->set_output(1, values); - return Status::OK(); - }) - .Doc(R"doc( -Outputs all keys and values in the table. - -table_handle: Handle to the table. -keys: Vector of all keys present in the table. -values: Tensor of all values in the table. Indexed in parallel with `keys`. 
-)doc"); - -REGISTER_OP("LookupTableExportV2") - .Input("table_handle: resource") - .Output("keys: Tkeys") - .Output("values: Tvalues") - .Attr("Tkeys: type") - .Attr("Tvalues: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - ShapeHandle values = c->UnknownShape(); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); - ShapeHandle keys = c->Vector(c->Dim(values, 0)); - c->set_output(0, keys); - c->set_output(1, values); - return Status::OK(); - }) - .Doc(R"doc( -Outputs all keys and values in the table. - -table_handle: Handle to the table. -keys: Vector of all keys present in the table. -values: Tensor of all values in the table. Indexed in parallel with `keys`. -)doc"); - -REGISTER_OP("LookupTableImport") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // TODO(ebrevdo): Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Replaces the contents of the table with the specified keys and values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("LookupTableImportV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // TODO: Validate keys and values shape. 
- return Status::OK(); - }) - .Doc(R"doc( -Replaces the contents of the table with the specified keys and values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("HashTable") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates a non-initialized hash table. - -This op creates a hash table, specifying the type of its keys and values. -Before using the table you will have to initialize it. After initialization the -table will be immutable. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("HashTableV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates a non-initialized hash table. - -This op creates a hash table, specifying the type of its keys and values. -Before using the table you will have to initialize it. After initialization the -table will be immutable. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. 
- Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTable") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. 
-container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableOfTensors") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a vector. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableOfTensorsV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a vector. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. 
-container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableDenseHashTable") - .Input("empty_key: key_dtype") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .Attr("initial_num_buckets: int = 131072") // 2^17 - .Attr("max_load_factor: float = 0.8") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table that uses tensors as the backing store. It uses -"open addressing" with quadratic reprobing to resolve collisions. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -empty_key: The key used to represent empty key buckets internally. Must not - be used in insert or lookup operations. -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -value_shape: The shape of each value. -initial_num_buckets: The initial number of hash table buckets. Must be a power - to 2. -max_load_factor: The maximum ratio between number of entries and number of - buckets before growing the table. Must be between 0 and 1. 
-)doc"); - -REGISTER_OP("MutableDenseHashTableV2") - .Input("empty_key: key_dtype") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .Attr("initial_num_buckets: int = 131072") // 2^17 - .Attr("max_load_factor: float = 0.8") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table that uses tensors as the backing store. It uses -"open addressing" with quadratic reprobing to resolve collisions. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -empty_key: The key used to represent empty key buckets internally. Must not - be used in insert or lookup operations. -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -value_shape: The shape of each value. -initial_num_buckets: The initial number of hash table buckets. Must be a power - to 2. -max_load_factor: The maximum ratio between number of entries and number of - buckets before growing the table. Must be between 0 and 1. 
-)doc"); - -REGISTER_OP("InitializeTable") - .Input("table_handle: Ref(string)") - .Input("keys: Tkey") - .Input("values: Tval") - .Attr("Tkey: type") - .Attr("Tval: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - ShapeHandle keys; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); - TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); - return Status::OK(); - }) - .Doc(R"doc( -Table initializer that takes two tensors for keys and values respectively. - -table_handle: Handle to a table which will be initialized. -keys: Keys of type Tkey. -values: Values of type Tval. -)doc"); - -REGISTER_OP("InitializeTableV2") - .Input("table_handle: resource") - .Input("keys: Tkey") - .Input("values: Tval") - .Attr("Tkey: type") - .Attr("Tval: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - ShapeHandle keys; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); - TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); - return Status::OK(); - }) - .Doc(R"doc( -Table initializer that takes two tensors for keys and values respectively. - -table_handle: Handle to a table which will be initialized. -keys: Keys of type Tkey. -values: Values of type Tval. 
-)doc"); - -REGISTER_OP("InitializeTableFromTextFile") - .Input("table_handle: Ref(string)") - .Input("filename: string") - .Attr("key_index: int >= -2") - .Attr("value_index: int >= -2") - .Attr("vocab_size: int >= -1 = -1") - .Attr("delimiter: string = '\t'") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); - return Status::OK(); - }) - .Doc(R"doc( -Initializes a table from a text file. - -It inserts one key-value pair into the table for each line of the file. -The key and value is extracted from the whole line content, elements from the -split line based on `delimiter` or the line number (starting from zero). -Where to extract the key and value from a line is specified by `key_index` and -`value_index`. - -- A value of -1 means use the line number(starting from zero), expects `int64`. -- A value of -2 means use the whole line content, expects `string`. -- A value >= 0 means use the index (starting at zero) of the split line based - on `delimiter`. - -table_handle: Handle to a table which will be initialized. -filename: Filename of a vocabulary text file. -key_index: Column index in a line to get the table `key` values from. -value_index: Column index that represents information of a line to get the table - `value` values from. -vocab_size: Number of elements of the file, use -1 if unknown. -delimiter: Delimiter to separate fields in a line. 
-)doc"); - -REGISTER_OP("InitializeTableFromTextFileV2") - .Input("table_handle: resource") - .Input("filename: string") - .Attr("key_index: int >= -2") - .Attr("value_index: int >= -2") - .Attr("vocab_size: int >= -1 = -1") - .Attr("delimiter: string = '\t'") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); - return Status::OK(); - }) - .Doc(R"doc( -Initializes a table from a text file. - -It inserts one key-value pair into the table for each line of the file. -The key and value is extracted from the whole line content, elements from the -split line based on `delimiter` or the line number (starting from zero). -Where to extract the key and value from a line is specified by `key_index` and -`value_index`. - -- A value of -1 means use the line number(starting from zero), expects `int64`. -- A value of -2 means use the whole line content, expects `string`. -- A value >= 0 means use the index (starting at zero) of the split line based - on `delimiter`. - -table_handle: Handle to a table which will be initialized. -filename: Filename of a vocabulary text file. -key_index: Column index in a line to get the table `key` values from. -value_index: Column index that represents information of a line to get the table - `value` values from. -vocab_size: Number of elements of the file, use -1 if unknown. -delimiter: Delimiter to separate fields in a line. -)doc"); - REGISTER_OP("GetSessionHandle") .Input("value: T") .Output("handle: string") diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc new file mode 100644 index 00000000000..498a65690d0 --- /dev/null +++ b/tensorflow/core/ops/lookup_ops.cc @@ -0,0 +1,666 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// -------------------------------------------------------------------------- + +namespace { +Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_handle; + for (int i = 0; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + +Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + for (int i = 1; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + +Status TwoElementOutput(InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); +} + +Status ScalarOutput(InferenceContext* 
c) { + c->set_output(0, c->Scalar()); + return Status::OK(); +} +} // namespace + +REGISTER_OP("LookupTableFind") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("default_value: Tout") + .Output("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // Default value must be scalar or vector. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +Looks up keys in a table, outputs the corresponding values. + +The tensor `keys` must of the same type as the keys of the table. +The output `values` is of the type of the table values. + +The scalar `default_value` is the value output for keys not present in the +table. It must also be of the same type as the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Same shape as `keys`. Values found in the table, or `default_values` + for missing keys. +)doc"); + +REGISTER_OP("LookupTableFindV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("default_value: Tout") + .Output("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // Default value must be scalar or vector. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +Looks up keys in a table, outputs the corresponding values. + +The tensor `keys` must of the same type as the keys of the table. +The output `values` is of the type of the table values. 
+ +The scalar `default_value` is the value output for keys not present in the +table. It must also be of the same type as the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Same shape as `keys`. Values found in the table, or `default_values` + for missing keys. +)doc"); + +REGISTER_OP("LookupTableInsert") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Updates the table to associates keys with values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableInsertV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // TODO: Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Updates the table to associates keys with values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. 
+)doc"); + +REGISTER_OP("LookupTableSize") + .Input("table_handle: Ref(string)") + .Output("size: int64") + .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: Handle to the table. +size: Scalar that contains number of elements in the table. +)doc"); + +REGISTER_OP("LookupTableSizeV2") + .Input("table_handle: resource") + .Output("size: int64") + .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs) + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: Handle to the table. +size: Scalar that contains number of elements in the table. +)doc"); + +REGISTER_OP("LookupTableExport") + .Input("table_handle: Ref(string)") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); + ShapeHandle keys = c->Vector(c->Dim(values, 0)); + c->set_output(0, keys); + c->set_output(1, values); + return Status::OK(); + }) + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. 
+)doc"); + +REGISTER_OP("LookupTableExportV2") + .Input("table_handle: resource") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); + ShapeHandle keys = c->Vector(c->Dim(values, 0)); + c->set_output(0, keys); + c->set_output(1, values); + return Status::OK(); + }) + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. +)doc"); + +REGISTER_OP("LookupTableImport") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Replaces the contents of the table with the specified keys and values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableImportV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // TODO: Validate keys and values shape. 
+ return Status::OK(); + }) + .Doc(R"doc( +Replaces the contents of the table with the specified keys and values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("HashTable") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates a non-initialized hash table. + +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("HashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates a non-initialized hash table. + +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. 
+ Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTable") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. 
+container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableOfTensors") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableOfTensorsV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. 
+container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableDenseHashTable") + .Input("empty_key: key_dtype") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .Attr("initial_num_buckets: int = 131072") // 2^17 + .Attr("max_load_factor: float = 0.8") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table that uses tensors as the backing store. It uses +"open addressing" with quadratic reprobing to resolve collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +empty_key: The key used to represent empty key buckets internally. Must not + be used in insert or lookup operations. +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +value_shape: The shape of each value. +initial_num_buckets: The initial number of hash table buckets. Must be a power + of 2. +max_load_factor: The maximum ratio between number of entries and number of + buckets before growing the table. Must be between 0 and 1.
+)doc"); + +REGISTER_OP("MutableDenseHashTableV2") + .Input("empty_key: key_dtype") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .Attr("initial_num_buckets: int = 131072") // 2^17 + .Attr("max_load_factor: float = 0.8") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table that uses tensors as the backing store. It uses +"open addressing" with quadratic reprobing to resolve collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +empty_key: The key used to represent empty key buckets internally. Must not + be used in insert or lookup operations. +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +value_shape: The shape of each value. +initial_num_buckets: The initial number of hash table buckets. Must be a power + of 2. +max_load_factor: The maximum ratio between number of entries and number of + buckets before growing the table. Must be between 0 and 1.
+)doc"); + +REGISTER_OP("InitializeTable") + .Input("table_handle: Ref(string)") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); + return Status::OK(); + }) + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: Handle to a table which will be initialized. +keys: Keys of type Tkey. +values: Values of type Tval. +)doc"); + +REGISTER_OP("InitializeTableV2") + .Input("table_handle: resource") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); + return Status::OK(); + }) + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: Handle to a table which will be initialized. +keys: Keys of type Tkey. +values: Values of type Tval. 
+)doc"); + +REGISTER_OP("InitializeTableFromTextFile") + .Input("table_handle: Ref(string)") + .Input("filename: string") + .Attr("key_index: int >= -2") + .Attr("value_index: int >= -2") + .Attr("vocab_size: int >= -1 = -1") + .Attr("delimiter: string = '\t'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); + return Status::OK(); + }) + .Doc(R"doc( +Initializes a table from a text file. + +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + +table_handle: Handle to a table which will be initialized. +filename: Filename of a vocabulary text file. +key_index: Column index in a line to get the table `key` values from. +value_index: Column index that represents information of a line to get the table + `value` values from. +vocab_size: Number of elements of the file, use -1 if unknown. +delimiter: Delimiter to separate fields in a line. 
+)doc"); + +REGISTER_OP("InitializeTableFromTextFileV2") + .Input("table_handle: resource") + .Input("filename: string") + .Attr("key_index: int >= -2") + .Attr("value_index: int >= -2") + .Attr("vocab_size: int >= -1 = -1") + .Attr("delimiter: string = '\t'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); + return Status::OK(); + }) + .Doc(R"doc( +Initializes a table from a text file. + +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + +table_handle: Handle to a table which will be initialized. +filename: Filename of a vocabulary text file. +key_index: Column index in a line to get the table `key` values from. +value_index: Column index that represents information of a line to get the table + `value` values from. +vocab_size: Number of elements of the file, use -1 if unknown. +delimiter: Delimiter to separate fields in a line. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/protobuf/cluster.proto b/tensorflow/core/protobuf/cluster.proto new file mode 100644 index 00000000000..33c87eefe02 --- /dev/null +++ b/tensorflow/core/protobuf/cluster.proto @@ -0,0 +1,82 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ClusterProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +// This file contains protos to be used when defining a TensorFlow +// cluster. +// +// EXAMPLES +// -------- +// +// 1. A single-process cluster, containing "/job:local/task:0". +// +// Cluster: +// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } } +// +// Server: +// cluster { $CLUSTER } job_name: 'local' task_index: 0 +// +// 2. A two-process cluster, containing "/job:local/task:{0,1}". +// +// Cluster: +// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } +// tasks { key: 1 value: 'localhost:2223' } } +// +// Servers: +// cluster { $CLUSTER } job_name: 'local' task_index: 0 +// cluster { $CLUSTER } job_name: 'local' task_index: 1 +// +// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and +// "/job:ps/task:{0,1}". 
+// +// Cluster: +// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' } +// tasks { key: 1 value: 'worker2:2222' } +// tasks { key: 2 value: 'worker3:2222' } } +// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } +// tasks { key: 1 value: 'ps1:2222' } } +// +// Servers: +// cluster { $CLUSTER } job_name: 'worker' task_index: 0 +// cluster { $CLUSTER } job_name: 'worker' task_index: 1 +// cluster { $CLUSTER } job_name: 'worker' task_index: 2 +// cluster { $CLUSTER } job_name: 'ps' task_index: 0 +// cluster { $CLUSTER } job_name: 'ps' task_index: 1 + +// Defines a single job in a TensorFlow cluster. +message JobDef { + // The name of this job. + string name = 1; + + // Mapping from task ID to "hostname:port" string. + // + // If the `name` field contains "worker", and the `tasks` map contains a + // mapping from 7 to "example.org:2222", then the device prefix + // "/job:worker/task:7" will be assigned to "example.org:2222". + map<int32, string> tasks = 2; +} + +// Defines a TensorFlow cluster as a set of jobs. +message ClusterDef { + // The jobs that comprise the cluster. + repeated JobDef job = 1; +} diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 5c0f7232ebd..630f47633f8 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -10,6 +10,7 @@ import "tensorflow/core/framework/cost_graph.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/step_stats.proto"; import "tensorflow/core/protobuf/debug.proto"; +import "tensorflow/core/protobuf/cluster.proto"; import "tensorflow/core/protobuf/rewriter_config.proto"; message GPUOptions { @@ -259,6 +260,11 @@ message ConfigProto { // Options that apply when this session uses the distributed runtime. RPCOptions rpc_options = 13; + + // Optional list of all workers to use in this session. + ClusterDef cluster_def = 14; + + // Next: 15 }; // Options for a single Run() call.
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto index de91b6133e4..e607b1c42a5 100644 --- a/tensorflow/core/protobuf/master.proto +++ b/tensorflow/core/protobuf/master.proto @@ -38,6 +38,9 @@ message CreateSessionRequest { // Configuration options. ConfigProto config = 2; + + // The target string used from the client's perspective. + string target = 3; } message CreateSessionResponse { diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto index c4077bd98e4..6199e707e5a 100644 --- a/tensorflow/core/protobuf/tensorflow_server.proto +++ b/tensorflow/core/protobuf/tensorflow_server.proto @@ -16,6 +16,7 @@ limitations under the License. syntax = "proto3"; import "tensorflow/core/protobuf/config.proto"; +import "tensorflow/core/protobuf/cluster.proto"; package tensorflow; option cc_enable_arenas = true; @@ -23,69 +24,6 @@ option java_outer_classname = "ServerProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.distruntime"; -// This file contains protos to be used when defining a TensorFlow -// cluster, and a server within that cluster. -// -// EXAMPLES -// -------- -// -// 1. A single-process cluster, containing "/job:local/task:0". -// -// Cluster: -// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } } -// -// Server: -// cluster { $CLUSTER } job_name: 'local' task_index: 0 -// -// 2. A two-process cluster, containing "/job:local/task:{0,1}". -// -// Cluster: -// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } -// tasks { key: 1 value: 'localhost:2223' } } -// -// Servers: -// cluster { $CLUSTER } job_name: 'local' task_index: 0 -// cluster { $CLUSTER } job_name: 'local' task_index: 1 -// -// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and -// "/job:ps/task:{0,1}". 
-// -// Cluster: -// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' } -// tasks { key: 1 value: 'worker2:2222' } -// tasks { key: 2 value: 'worker3:2222' } } -// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } -// tasks { key: 1 value: 'ps1:2222' } } -// -// Servers: -// cluster { $CLUSTER } job_name: 'worker' task_index: 0 -// cluster { $CLUSTER } job_name: 'worker' task_index: 1 -// cluster { $CLUSTER } job_name: 'worker' task_index: 2 -// cluster { $CLUSTER } job_name: 'ps' task_index: 0 -// cluster { $CLUSTER } job_name: 'ps' task_index: 1 - -// Defines a single job in a TensorFlow cluster. -message JobDef { - // The name of this job. - string name = 1; - - // Mapping from task ID to "hostname:port" string. - // - // If the `name` field contains "worker", and the `tasks` map contains a - // mapping from 7 to "example.org:2222", then the device prefix - // "/job:worker/task:7" will be assigned to "example.org:2222". - // - // NOTE(mrry): Currently, only a dense task ID space starting at 0 is - // supported. - map<int32, string> tasks = 2; -} - -// Defines a TensorFlow cluster as a set of jobs. -message ClusterDef { - // The jobs that comprise the cluster. - repeated JobDef job = 1; -} - // Defines the configuration of a single TensorFlow server. message ServerDef { // The cluster of which this server is a member. diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 661327847c1..cf05aece39a 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -119,6 +119,10 @@ message RegisterGraphResponse { //////////////////////////////////////////////////////////////////////////////// message DeregisterGraphRequest { + // The session_handle used when registering the graph. If session_handle is + // empty, a single global namespace is used. + string session_handle = 2; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService.
string graph_handle = 1; @@ -167,6 +171,12 @@ message ExecutorOpts { }; message RunGraphRequest { + // session_handle is the master-generated unique id for this session. + // If session_handle is non-empty, it must be the same as used when + // registering the graph. If it is empty, a single global namespace is used to + // search for the graph_handle. + string session_handle = 8; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. string graph_handle = 1; @@ -193,6 +203,8 @@ message RunGraphRequest { bool is_partial = 6; // True if this is the last partial run request in a sequence of requests. bool is_last_partial_run = 7; + + // Next: 9 } message RunGraphResponse { diff --git a/tensorflow/docs_src/get_started/get_started.md b/tensorflow/docs_src/get_started/get_started.md index b52adc3790a..00cc10cd347 100644 --- a/tensorflow/docs_src/get_started/get_started.md +++ b/tensorflow/docs_src/get_started/get_started.md @@ -372,25 +372,36 @@ features = [tf.contrib.layers.real_valued_column("x", dimension=1)] estimator = tf.contrib.learn.LinearRegressor(feature_columns=features) # TensorFlow provides many helper methods to read and set up data sets. -# Here we use `numpy_input_fn`. We have to tell the function how many batches +# Here we use two data sets: one for training and one for evaluation. +# We have to tell the function how many batches # of data (num_epochs) we want and how big each batch should be.
-x = np.array([1., 2., 3., 4.]) -y = np.array([0., -1., -2., -3.]) -input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x}, y, batch_size=4, +x_train = np.array([1., 2., 3., 4.]) +y_train = np.array([0., -1., -2., -3.]) +x_eval = np.array([2., 5., 8., 1.]) +y_eval = np.array([-1.01, -4.1, -7, 0.]) +input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x_train}, y_train, + batch_size=4, num_epochs=1000) +eval_input_fn = tf.contrib.learn.io.numpy_input_fn( + {"x":x_eval}, y_eval, batch_size=4, num_epochs=1000) -# We can invoke 1000 training steps by invoking the `fit` method and passing the +# We can invoke 1000 training steps by invoking the `fit` method and passing the # training data set. estimator.fit(input_fn=input_fn, steps=1000) -# Here we evaluate how well our model did. In a real example, we would want -# to use a separate validation and testing data set to avoid overfitting. -print(estimator.evaluate(input_fn=input_fn)) +# Here we evaluate how well our model did. +train_loss = estimator.evaluate(input_fn=input_fn) +eval_loss = estimator.evaluate(input_fn=eval_input_fn) +print("train loss: %r"% train_loss) +print("eval loss: %r"% eval_loss) ``` When run, it produces ``` - {'global_step': 1000, 'loss': 1.9650059e-11} + train loss: {'global_step': 1000, 'loss': 4.3049088e-08} + eval loss: {'global_step': 1000, 'loss': 0.0025487561} ``` +Notice how our eval data has a higher loss, but it is still close to zero. +That means we are learning properly.
### A custom model @@ -432,19 +443,26 @@ def model(features, labels, mode): train_op=train) estimator = tf.contrib.learn.Estimator(model_fn=model) -# define our data set -x = np.array([1., 2., 3., 4.]) -y = np.array([0., -1., -2., -3.]) -input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x}, y, 4, num_epochs=1000) +# define our data sets +x_train = np.array([1., 2., 3., 4.]) +y_train = np.array([0., -1., -2., -3.]) +x_eval = np.array([2., 5., 8., 1.]) +y_eval = np.array([-1.01, -4.1, -7, 0.]) +input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x_train}, y_train, 4, num_epochs=1000) +eval_input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x_eval}, y_eval, 4, num_epochs=1000) # train estimator.fit(input_fn=input_fn, steps=1000) -# evaluate our model -print(estimator.evaluate(input_fn=input_fn, steps=10)) +# Here we evaluate how well our model did. +train_loss = estimator.evaluate(input_fn=input_fn) +eval_loss = estimator.evaluate(input_fn=eval_input_fn) +print("train loss: %r"% train_loss) +print("eval loss: %r"% eval_loss) ``` When run, it produces -```python -{'loss': 5.9819476e-11, 'global_step': 1000} +``` +train loss: {'global_step': 1000, 'loss': 4.9380226e-11} +eval loss: {'global_step': 1000, 'loss': 0.01010081} ``` Notice how the contents of the custom `model()` function are very similar diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 5304779c004..55d9c2c08f3 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -218,11 +218,7 @@ and Mac OS X: And the following comand line executes the `HelloTF` program on Windows: -
java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF
- -And the following comand line executes the `HelloTF` program on Windows: - -
java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF
+
java -cp libtensorflow-1.1.0.jar;. -Djava.library.path=jni HelloTF
If the program prints Hello from version, you've successfully installed TensorFlow for Java and are ready to use the API. If the program diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md index 309b39451fd..acdca2bad4f 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/programmers_guide/index.md @@ -39,6 +39,11 @@ trained graph. The following guide details `MetaGraph` objects: * @{$meta_graph$Exporting and Importing a MetaGraph}. +`SavedModel` is the universal serialization format for Tensorflow models. TensorFlow provides SavedModel CLI (command-line interface) as a tool to inspect and execute a MetaGraph in a SavedModel. The detailed usages and examples are +documented in the following guide: + + * @{$saved_model_cli$SavedModel CLI (Command-Line Interface)}. + To learn about the TensorFlow versioning scheme, consult the following two guides: diff --git a/tensorflow/docs_src/programmers_guide/supervisor.md b/tensorflow/docs_src/programmers_guide/supervisor.md index 82ed1c2cf76..55a090df589 100644 --- a/tensorflow/docs_src/programmers_guide/supervisor.md +++ b/tensorflow/docs_src/programmers_guide/supervisor.md @@ -362,8 +362,8 @@ following keyword arguments to the `Supervisor()` constructor: If not specified, the supervisor uses the first op in the `tf.GraphKeys.LOCAL_INIT_OP` collection. If the collection is empty the supervisor adds an op to initialize all the tables and local variables in - the graph by calling `tf.initialize_all_tables()` and - `tf.initialize_all_local_variables()`. + the graph by calling `tf.tables_initializer()` and + `tf.local_variables_initializer()`. Pass `None` to not use a local init op. 
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java index b26a2316782..bc391269255 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java @@ -194,13 +194,12 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab yuvBytes[0], yuvBytes[1], yuvBytes[2], - rgbBytes, previewWidth, previewHeight, yRowStride, uvRowStride, uvPixelStride, - false); + rgbBytes); image.close(); } catch (final Exception e) { diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java index 206a99f3e3d..5800f80651b 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java @@ -124,7 +124,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable borderedText = new BorderedText(textSizePx); borderedText.setTypeface(Typeface.MONOSPACE); - tracker = new MultiBoxTracker(getResources().getDisplayMetrics()); + tracker = new MultiBoxTracker(this); if (USE_YOLO) { detector = @@ -273,13 +273,12 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable yuvBytes[0], yuvBytes[1], yuvBytes[2], - rgbBytes, previewWidth, previewHeight, yRowStride, uvRowStride, uvPixelStride, - false); + rgbBytes); image.close(); } catch (final Exception e) { diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java index 7634be5c020..7afe2bf5412 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java @@ 
-65,10 +65,6 @@ import org.tensorflow.demo.R; * Artistic Style" (https://arxiv.org/abs/1610.07629) */ public class StylizeActivity extends CameraActivity implements OnImageAvailableListener { - static { - System.loadLibrary("tensorflow_demo"); - } - private static final Logger LOGGER = new Logger(); private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb"; @@ -509,17 +505,17 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL final int yRowStride = planes[0].getRowStride(); final int uvRowStride = planes[1].getRowStride(); final int uvPixelStride = planes[1].getPixelStride(); + ImageUtils.convertYUV420ToARGB8888( yuvBytes[0], yuvBytes[1], yuvBytes[2], - rgbBytes, previewWidth, previewHeight, yRowStride, uvRowStride, uvPixelStride, - false); + rgbBytes); image.close(); } catch (final Exception e) { diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java index f3e7114335f..1dcf9f55efe 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java @@ -41,10 +41,6 @@ import org.tensorflow.demo.env.Logger; public class TensorFlowMultiBoxDetector implements Classifier { private static final Logger LOGGER = new Logger(); - static { - System.loadLibrary("tensorflow_demo"); - } - // Only return this many results with at least this confidence. 
private static final int MAX_RESULTS = Integer.MAX_VALUE; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java index 174723071da..b7e36a2379d 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java @@ -31,10 +31,6 @@ import org.tensorflow.demo.env.SplitTimer; public class TensorFlowYoloDetector implements Classifier { private static final Logger LOGGER = new Logger(); - static { - System.loadLibrary("tensorflow_demo"); - } - // Only return this many results with at least this confidence. private static final int MAX_RESULTS = 5; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java index db929e5e087..5f2ff9164cc 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java @@ -27,6 +27,14 @@ import java.io.FileOutputStream; public class ImageUtils { @SuppressWarnings("unused") private static final Logger LOGGER = new Logger(); + + static { + try { + System.loadLibrary("tensorflow_demo"); + } catch (UnsatisfiedLinkError e) { + LOGGER.w("Native library not found, native RGB -> YUV conversion may be unavailable."); + } + } /** * Utility method to compute the allocated size in bytes of a YUV420SP image @@ -83,10 +91,84 @@ public class ImageUtils { } } + // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges + // are normalized to eight bits. + static final int kMaxChannelValue = 262143; + + // Always prefer the native implementation if available. 
+ private static boolean useNativeConversion = true; + + public static void convertYUV420ToARGB8888( + byte[] yData, + byte[] uData, + byte[] vData, + int width, + int height, + int yRowStride, + int uvRowStride, + int uvPixelStride, + int[] out) { + if (useNativeConversion) { + try { + convertYUV420ToARGB8888( + yData, uData, vData, out, width, height, yRowStride, uvRowStride, uvPixelStride, false); + return; + } catch (UnsatisfiedLinkError e) { + LOGGER.w("Native YUV -> RGB implementation not found, falling back to Java implementation"); + useNativeConversion = false; + } + } + + int i = 0; + for (int y = 0; y < height; y++) { + int pY = yRowStride * y; + int uv_row_start = uvRowStride * (y >> 1); + int pUV = uv_row_start; + int pV = uv_row_start; + + for (int x = 0; x < width; x++) { + int uv_offset = pUV + (x >> 1) * uvPixelStride; + out[i++] = + YUV2RGB( + convertByteToInt(yData, pY + x), + convertByteToInt(uData, uv_offset), + convertByteToInt(vData, uv_offset)); + } + } + } + + private static int convertByteToInt(byte[] arr, int pos) { + return arr[pos] & 0xFF; + } + + private static int YUV2RGB(int nY, int nU, int nV) { + nY -= 16; + nU -= 128; + nV -= 128; + if (nY < 0) nY = 0; + + // This is the floating point equivalent. We do the conversion in integer + // because some Android devices do not have floating point in hardware. + // nR = (int)(1.164 * nY + 1.596 * nV); + // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); + // nB = (int)(1.164 * nY + 2.018 * nU); + + final int foo = 1192 * nY; + int nR = foo + 1634 * nV; + int nG = foo - 833 * nV - 400 * nU; + int nB = foo + 2066 * nU; + + nR = Math.min(kMaxChannelValue, Math.max(0, nR)); + nG = Math.min(kMaxChannelValue, Math.max(0, nG)); + nB = Math.min(kMaxChannelValue, Math.max(0, nB)); + + return 0xff000000 | ((nR << 6) & 0x00ff0000) | ((nG >> 2) & 0x0000FF00) | ((nB >> 10) & 0xff); + } + /** - * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width - * and height.
The input and output must already be allocated and non-null. - * For efficiency, no error checking is performed. + * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width and height. The + * input and output must already be allocated and non-null. For efficiency, no error checking is + * performed. * * @param input The array of YUV 4:2:0 input data. * @param output A pre-allocated array for the ARGB 8:8:8:8 output data. diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java index 49c91d600da..91d1f9feb18 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java @@ -15,6 +15,7 @@ limitations under the License. package org.tensorflow.demo.tracking; +import android.content.Context; import android.graphics.Canvas; import android.graphics.Color; import android.graphics.Matrix; @@ -24,9 +25,9 @@ import android.graphics.Paint.Join; import android.graphics.Paint.Style; import android.graphics.RectF; import android.text.TextUtils; -import android.util.DisplayMetrics; import android.util.Pair; import android.util.TypedValue; +import android.widget.Toast; import java.util.LinkedList; import java.util.List; import java.util.Queue; @@ -69,6 +70,7 @@ public class MultiBoxTracker { private static class TrackedRecognition { ObjectTracker.TrackedObject trackedObject; + RectF location; float detectionConfidence; int color; String title; @@ -87,8 +89,10 @@ public class MultiBoxTracker { private int frameHeight; private int sensorOrientation; + private Context context; - public MultiBoxTracker(final DisplayMetrics metrics) { + public MultiBoxTracker(final Context context) { + this.context = context; for (final int color : COLORS) { availableColors.add(color); } @@ -100,7 +104,9 @@ public class MultiBoxTracker { 
boxPaint.setStrokeJoin(Join.ROUND); boxPaint.setStrokeMiter(100); - textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, metrics); + textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics()); borderedText = new BorderedText(textSizePx); } @@ -152,10 +158,6 @@ public class MultiBoxTracker { } public synchronized void draw(final Canvas canvas) { - if (objectTracker == null) { - return; - } - // TODO(andrewharp): This may not work for non-90 deg rotations. final float multiplier = Math.min(canvas.getWidth() / (float) frameHeight, canvas.getHeight() / (float) frameWidth); @@ -168,9 +170,11 @@ public class MultiBoxTracker { sensorOrientation, false); for (final TrackedRecognition recognition : trackedObjects) { - final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; + final RectF trackedPos = + (objectTracker != null) + ? recognition.trackedObject.getTrackedPositionInPreviewFrame() + : new RectF(recognition.location); - final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame(); getFrameToCanvasMatrix().mapRect(trackedPos); boxPaint.setColor(recognition.color); @@ -185,6 +189,8 @@ public class MultiBoxTracker { } } + private boolean initialized = false; + public synchronized void onFrame( final int w, final int h, @@ -192,7 +198,7 @@ public class MultiBoxTracker { final int sensorOrienation, final byte[] frame, final long timestamp) { - if (objectTracker == null) { + if (objectTracker == null && !initialized) { ObjectTracker.clearInstance(); logger.i("Initializing ObjectTracker: %dx%d", w, h); @@ -200,6 +206,19 @@ public class MultiBoxTracker { frameWidth = w; frameHeight = h; this.sensorOrientation = sensorOrienation; + initialized = true; + + if (objectTracker == null) { + String message = + "Object tracking support not found. 
" + + "See tensorflow/examples/android/README.md for details."; + Toast.makeText(context, message, Toast.LENGTH_LONG).show(); + logger.e(message); + } + } + + if (objectTracker == null) { + return; } objectTracker.nextFrame(frame, null, timestamp, null, true); @@ -255,7 +274,20 @@ public class MultiBoxTracker { } if (objectTracker == null) { - logger.w("No ObjectTracker, can't track anything!"); + trackedObjects.clear(); + for (final Pair potential : rectsToTrack) { + final TrackedRecognition trackedRecognition = new TrackedRecognition(); + trackedRecognition.detectionConfidence = potential.first; + trackedRecognition.location = new RectF(potential.second.getLocation()); + trackedRecognition.trackedObject = null; + trackedRecognition.title = potential.second.getTitle(); + trackedRecognition.color = COLORS[trackedObjects.size()]; + trackedObjects.add(trackedRecognition); + + if (trackedObjects.size() >= COLORS.length) { + break; + } + } return; } diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java index 82de634baff..69f202b5681 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java @@ -48,7 +48,18 @@ import org.tensorflow.demo.env.Size; * ObjectTracker still exists. 
*/ public class ObjectTracker { - private final Logger logger = new Logger(); + private static final Logger LOGGER = new Logger(); + + private static boolean libraryFound = false; + + static { + try { + System.loadLibrary("tensorflow_demo"); + libraryFound = true; + } catch (UnsatisfiedLinkError e) { + LOGGER.e("libtensorflow_demo.so not found, tracking unavailable"); + } + } private static final boolean DRAW_TEXT = false; @@ -194,6 +205,13 @@ public class ObjectTracker { public static synchronized ObjectTracker getInstance( final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) { + if (!libraryFound) { + LOGGER.e( + "Native object tracking support not found. " + + "See tensorflow/examples/android/README.md for details."); + return null; + } + if (instance == null) { instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack); instance.init(); @@ -519,7 +537,7 @@ public class ObjectTracker { checkValidObject(); synchronized (ObjectTracker.this) { if (lastExternalPositionTime > timestamp) { - logger.w("Tried to use older position time!"); + LOGGER.w("Tried to use older position time!"); return; } final RectF externalPosition = downscaleRect(position); @@ -640,8 +658,4 @@ public class ObjectTracker { protected static native void downsampleImageNative( int width, int height, int rowStride, byte[] input, int factor, byte[] output); - - static { - System.loadLibrary("tensorflow_demo"); - } } diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md index e32c21ca720..a1b4255292b 100644 --- a/tensorflow/go/README.md +++ b/tensorflow/go/README.md @@ -9,24 +9,22 @@ Construct and execute TensorFlow graphs in Go. > (`github.com/tensorflow/tensorflow/tensorflow/go`). ## Quickstart - 1. Download and extract the TensorFlow C library, preferably into `/usr/local`. GPU-enabled versions require CUDA 8.0 and cuDNN 5.1. For other versions, the TensorFlow C library will have to be built from source (see below). 
- Linux: - [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.0.0.tar.gz), - [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.0.0.tar.gz) + [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.1.0.tar.gz), + [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.1.0.tar.gz) - OS X - [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.0.0.tar.gz), - [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-darwin-x86_64-1.0.0.tar.gz) + [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.1.0.tar.gz) The following shell snippet downloads and extracts into `/usr/local`: ```sh TF_TYPE="cpu" # Set to "gpu" for GPU support curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.0.0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0.tar.gz" | sudo tar -C /usr/local -xz ``` @@ -41,20 +39,7 @@ Construct and execute TensorFlow graphs in Go. ### Installing into locations other than `/usr/local` -The TensorFlow C library (`libtensorflow.so`) needs to be available at build -time (e.g., `go build`) and run time (`go test` or executing binaries). If the -library has not been extracted into `/usr/local`, then it needs to be made -available through the `LIBRARY_PATH` environment variable at build time and the -`LD_LIBRARY_PATH` environment variable (`DYLD_LIBRARY_PATH` on OS X) at run -time.
- -For example, if the TensorFlow C library was extracted into `/dir`, then: - -```sh -export LIBRARY_PATH=/dir/lib -export LD_LIBRARY_PATH=/dir/lib # For Linux -export DYLD_LIBRARY_PATH=/dir/lib # For OS X -``` +Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/install_go) ## Building the TensorFlow C library from source diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index c63be8bc5ee..eb4789a1829 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -3522,256 +3522,6 @@ func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Opera return scope.AddOperation(opspec) } -// Table initializer that takes two tensors for keys and values respectively. -// -// Arguments: -// table_handle: Handle to a table which will be initialized. -// keys: Keys of type Tkey. -// values: Values of type Tval. -// -// Returns the created operation. -func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InitializeTableV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. -type MutableHashTableV2Attr func(optionalAttr) - -// MutableHashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableV2Container(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. 
-// If not specified, defaults to "" -func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates an empty hash table. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MutableHashTableV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// HashTableV2Attr is an optional argument to HashTableV2. -type HashTableV2Attr func(optionalAttr) - -// HashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func HashTableV2Container(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// HashTableV2SharedName sets the optional shared_name attribute to value. 
-// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func HashTableV2SharedName(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates a non-initialized hash table. -// -// This op creates a hash table, specifying the type of its keys and values. -// Before using the table you will have to initialize it. After initialization the -// table will be immutable. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "HashTableV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Replaces the contents of the table with the specified keys and values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. -// -// Returns the created operation. 
-func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableImportV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// Outputs all keys and values in the table. -// -// Arguments: -// table_handle: Handle to the table. -// -// -// -// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. -func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} - opspec := tf.OpSpec{ - Type: "LookupTableExportV2", - Input: []tf.Input{ - table_handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Updates the table to associates keys with values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. -// -// Returns the created operation. -func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableInsertV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// Looks up keys in a table, outputs the corresponding values. -// -// The tensor `keys` must of the same type as the keys of the table. -// The output `values` is of the type of the table values. -// -// The scalar `default_value` is the value output for keys not present in the -// table. 
It must also be of the same type as the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// -// -// Returns Same shape as `keys`. Values found in the table, or `default_values` -// for missing keys. -func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableFindV2", - Input: []tf.Input{ - table_handle, keys, default_value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs. type FakeQuantWithMinMaxArgsAttr func(optionalAttr) @@ -5404,6 +5154,435 @@ func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Ou return op.Output(0) } +// Draw bounding boxes on a batch of images. +// +// Outputs a copy of `images` but draws on top of the pixels zero or more bounding +// boxes specified by the locations in `boxes`. The coordinates of the each +// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. +// +// For example, if an image is 100 x 200 pixels and the bounding box is +// `[0.1, 0.2, 0.5, 0.9]`, the bottom-left and upper-right coordinates of the +// bounding box will be `(10, 40)` to `(50, 180)`. +// +// Parts of the bounding box may fall outside the image. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. +// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding +// boxes. +// +// Returns 4-D with the same shape as `images`. The batch of input images with +// bounding boxes drawn on the images. 
+func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DrawBoundingBoxes", + Input: []tf.Input{ + images, boxes, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Convert one or more images from HSV to RGB. +// +// Outputs a tensor of the same shape as the `images` tensor, containing the RGB +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// See `rgb_to_hsv` for a description of the HSV encoding. +// +// Arguments: +// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// +// Returns `images` converted to RGB. +func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "HSVToRGB", + Input: []tf.Input{ + images, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Decode the first frame of a GIF-encoded image to a uint8 tensor. +// +// GIF with frame or transparency compression are not supported +// convert animated GIF from compressed to uncompressed by: +// +// convert $src.gif -coalesce $dst.gif +// +// Arguments: +// contents: 0-D. The GIF-encoded image. +// +// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order +func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DecodeGif", + Input: []tf.Input{ + contents, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodePngAttr is an optional argument to DecodePng. +type DecodePngAttr func(optionalAttr) + +// DecodePngChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. 
+// If not specified, defaults to 0 +func DecodePngChannels(value int64) DecodePngAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodePngDtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_UINT8 +func DecodePngDtype(value tf.DataType) DecodePngAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Decode a PNG-encoded image to a uint8 or uint16 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the PNG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// * 4: output an RGBA image. +// +// If needed, the PNG-encoded image is transformed to match the requested number +// of color channels. +// +// Arguments: +// contents: 0-D. The PNG-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`. +func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodePng", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adjust the contrast of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are +// interpreted as `[height, width, channels]`. The other dimensions only +// represent a collection of images, such as `[batch, height, width, channels].` +// +// Contrast is adjusted independently for each channel of each image. +// +// For each channel, the Op first computes the mean of the image pixels in the +// channel and then adjusts each component of each pixel to +// `(x - mean) * contrast_factor + mean`. +// +// Arguments: +// images: Images to adjust. At least 3-D. 
+// contrast_factor: A float multiplier for adjusting contrast. +// +// Returns The contrast-adjusted image or images. +func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustContrastv2", + Input: []tf.Input{ + images, contrast_factor, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeJpegAttr is an optional argument to DecodeJpeg. +type DecodeJpegAttr func(optionalAttr) + +// DecodeJpegChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodeJpegChannels(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodeJpegRatio sets the optional ratio attribute to value. +// +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeJpegRatio(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["ratio"] = value + } +} + +// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["fancy_upscaling"] = value + } +} + +// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// +// value: The minimum required fraction of lines before a truncated +// input is accepted. 
+// If not specified, defaults to 1 +func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeJpegDctMethod sets the optional dct_method attribute to value. +// +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeJpegDctMethod(value string) DecodeJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } +} + +// Decode a JPEG-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. +// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeJpeg", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. 
+type ResizeNearestNeighborGradAttr func(optionalAttr) + +// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, rescale grads by (orig_height - 1) / (height - 1), which +// exactly aligns the 4 corners of grads and original_image. If false, rescale by +// orig_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Computes the gradient of nearest neighbor interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The +// original input size. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients +// with respect to the input image. +func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeNearestNeighborGrad", + Input: []tf.Input{ + grads, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. +type ResizeNearestNeighborAttr func(optionalAttr) + +// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, rescale input by (new_height - 1) / (height - 1), which +// exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. 
+// If not specified, defaults to false +func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Resize `images` to `size` using nearest neighbor interpolation. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeNearestNeighbor", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the set of files matching one or more glob patterns. +// +// Note that this routine only supports wildcard characters in the +// basename portion of the pattern, not in the directory portion. +// +// Arguments: +// pattern: Shell wildcard pattern(s). Scalar or vector of type string. +// +// Returns A vector of matching filenames. +func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatchingFiles", + Input: []tf.Input{ + pattern, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Shuffle dimensions of x according to a permutation. +// +// The output `y` has the same rank as `x`. 
The shapes of `x` and `y` satisfy: +// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` +func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Transpose", + Input: []tf.Input{ + x, perm, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reads and outputs the entire contents of the input filename. +func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReadFile", + Input: []tf.Input{ + filename, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes softmax cross entropy cost and gradients to backpropagate. // // Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept @@ -6560,6 +6739,95 @@ func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { return op.Output(0) } +// ResizeBilinearAttr is an optional argument to ResizeBilinear. +type ResizeBilinearAttr func(optionalAttr) + +// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, rescale input by (new_height - 1) / (height - 1), which +// exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Resize `images` to `size` using bilinear interpolation. +// +// Input images can be of different types but output images are always float. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. 
+func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBilinear", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ProdAttr is an optional argument to Prod. +type ProdAttr func(optionalAttr) + +// ProdKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func ProdKeepDims(value bool) ProdAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the product of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `reduction_indices`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// reduction_indices: The dimensions to reduce. +// +// Returns The reduced tensor. +func Prod(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...ProdAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Prod", + Input: []tf.Input{ + input, reduction_indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. type DepthwiseConv2dNativeAttr func(optionalAttr) @@ -6770,6 +7038,181 @@ func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) return op.Output(0) } +// EncodeJpegAttr is an optional argument to EncodeJpeg. 
+type EncodeJpegAttr func(optionalAttr) + +// EncodeJpegFormat sets the optional format attribute to value. +// +// value: Per pixel image format. +// If not specified, defaults to "" +func EncodeJpegFormat(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["format"] = value + } +} + +// EncodeJpegQuality sets the optional quality attribute to value. +// +// value: Quality of the compression from 0 to 100 (higher is better and slower). +// If not specified, defaults to 95 +func EncodeJpegQuality(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["quality"] = value + } +} + +// EncodeJpegProgressive sets the optional progressive attribute to value. +// +// value: If True, create a JPEG that loads progressively (coarse to fine). +// If not specified, defaults to false +func EncodeJpegProgressive(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["progressive"] = value + } +} + +// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. +// +// value: If True, spend CPU/RAM to reduce size with no quality change. +// If not specified, defaults to false +func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["optimize_size"] = value + } +} + +// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. +// +// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. +// If not specified, defaults to true +func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["chroma_downsampling"] = value + } +} + +// EncodeJpegDensityUnit sets the optional density_unit attribute to value. +// +// value: Unit used to specify `x_density` and `y_density`: +// pixels per inch (`'in'`) or centimeter (`'cm'`). 
+// If not specified, defaults to "in" +func EncodeJpegDensityUnit(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["density_unit"] = value + } +} + +// EncodeJpegXDensity sets the optional x_density attribute to value. +// +// value: Horizontal pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegXDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["x_density"] = value + } +} + +// EncodeJpegYDensity sets the optional y_density attribute to value. +// +// value: Vertical pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegYDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["y_density"] = value + } +} + +// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. +// +// value: If not empty, embed this XMP metadata in the image header. +// If not specified, defaults to "" +func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["xmp_metadata"] = value + } +} + +// JPEG-encode an image. +// +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// +// The attr `format` can be used to override the color format of the encoded +// output. Values can be: +// +// * `''`: Use a default format based on the number of channels in the image. +// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension +// of `image` must be 1. +// * `rgb`: Output an RGB JPEG image. The `channels` dimension +// of `image` must be 3. +// +// If `format` is not specified or is the empty string, a default format is picked +// in function of the number of channels in `image`: +// +// * 1: Output a grayscale image. +// * 3: Output an RGB image. +// +// Arguments: +// image: 3-D with shape `[height, width, channels]`. +// +// Returns 0-D. JPEG-encoded image. 
+func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EncodeJpeg", + Input: []tf.Input{ + image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Gradients for batch normalization. +// +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// +// This op is deprecated. See `tf.nn.batch_normalization`. +// +// Arguments: +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this Tensor will be multiplied +// with the normalized Tensor. +// backprop: 4D backprop Tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +// +// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. 
+func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + opspec := tf.OpSpec{ + Type: "BatchNormWithGlobalNormalizationGrad", + Input: []tf.Input{ + t, m, v, gamma, backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + // Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. type Conv2DBackpropInputAttr func(optionalAttr) @@ -7160,6 +7603,51 @@ func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes return scope.AddOperation(opspec) } +// Writes contents to the file at input filename. Creates file if not existing. +// +// Arguments: +// filename: scalar. The name of the file to which we write the contents. +// contents: scalar. The content to be written to the output file. +// +// Returns the created operation. +func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "WriteFile", + Input: []tf.Input{ + filename, contents, + }, + } + return scope.AddOperation(opspec) +} + +// Computes the Cholesky decomposition of one or more square matrices. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices, with the same constraints as the single matrix Cholesky +// decomposition above. The output is a tensor of the same shape as the input +// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M, M]`. 
+func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cholesky", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns the rank of a tensor. // // This operation returns an integer representing the rank of `input`. @@ -7243,54 +7731,6 @@ func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, opt return output } -// BiasAddGradAttr is an optional argument to BiasAddGrad. -type BiasAddGradAttr func(optionalAttr) - -// BiasAddGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the bias tensor will be added to the last dimension -// of the value tensor. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// The tensor will be added to "in_channels", the third-to-the-last -// dimension. -// If not specified, defaults to "NHWC" -func BiasAddGradDataFormat(value string) BiasAddGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// The backward operation for "BiasAdd" on the "bias" tensor. -// -// It accumulates all the values from out_backprop into the feature dimension. -// For NHWC data format, the feature dimension is the last. For NCHW data format, -// the feature dimension is the third-to-last. -// -// Arguments: -// out_backprop: Any number of dimensions. -// -// Returns 1-D with size the feature dimension of `out_backprop`. 
-func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "BiasAddGrad", - Input: []tf.Input{ - out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Convert JSON-encoded Example records to binary protocol buffer strings. // // This op translates a tensor containing Example records, encoded using @@ -8024,27 +8464,51 @@ func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output return op.Output(0) } -// Convert one or more images from HSV to RGB. +// EncodePngAttr is an optional argument to EncodePng. +type EncodePngAttr func(optionalAttr) + +// EncodePngCompression sets the optional compression attribute to value. // -// Outputs a tensor of the same shape as the `images` tensor, containing the RGB -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. +// value: Compression level. +// If not specified, defaults to -1 +func EncodePngCompression(value int64) EncodePngAttr { + return func(m optionalAttr) { + m["compression"] = value + } +} + +// PNG-encode an image. // -// See `rgb_to_hsv` for a description of the HSV encoding. +// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` +// where `channels` is: +// +// * 1: for grayscale. +// * 2: for grayscale + alpha. +// * 3: for RGB. +// * 4: for RGBA. +// +// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder +// default or a value from 0 to 9. 9 is the highest compression level, generating +// the smallest output, but is slower. // // Arguments: -// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// image: 3-D with shape `[height, width, channels]`. // -// Returns `images` converted to RGB. 
-func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { +// Returns 0-D. PNG-encoded image. +func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "HSVToRGB", + Type: "EncodePng", Input: []tf.Input{ - images, + image, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -8976,29 +9440,6 @@ func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output, return op.Output(0), op.Output(1) } -// Returns the set of files matching one or more glob patterns. -// -// Note that this routine only supports wildcard characters in the -// basename portion of the pattern, not in the directory portion. -// -// Arguments: -// pattern: Shell wildcard pattern(s). Scalar or vector of type string. -// -// Returns A vector of matching filenames. -func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatchingFiles", - Input: []tf.Input{ - pattern, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the gradient of the sigmoid of `x` wrt its input. // // Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and @@ -10269,117 +10710,6 @@ func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max return op.Output(0), op.Output(1), op.Output(2) } -// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. -type InitializeTableFromTextFileV2Attr func(optionalAttr) - -// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. -// -// value: Number of elements of the file, use -1 if unknown. 
-// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { - return func(m optionalAttr) { - m["vocab_size"] = value - } -} - -// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. -// -// value: Delimiter to separate fields in a line. -// If not specified, defaults to "\t" -func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { - return func(m optionalAttr) { - m["delimiter"] = value - } -} - -// Initializes a table from a text file. -// -// It inserts one key-value pair into the table for each line of the file. -// The key and value is extracted from the whole line content, elements from the -// split line based on `delimiter` or the line number (starting from zero). -// Where to extract the key and value from a line is specified by `key_index` and -// `value_index`. -// -// - A value of -1 means use the line number(starting from zero), expects `int64`. -// - A value of -2 means use the whole line content, expects `string`. -// - A value >= 0 means use the index (starting at zero) of the split line based -// on `delimiter`. -// -// Arguments: -// table_handle: Handle to a table which will be initialized. -// filename: Filename of a vocabulary text file. -// key_index: Column index in a line to get the table `key` values from. -// value_index: Column index that represents information of a line to get the table -// `value` values from. -// -// Returns the created operation. 
-func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "InitializeTableFromTextFileV2", - Input: []tf.Input{ - table_handle, filename, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. -type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) - -// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Sparse update '*var' as FOBOS algorithm with fixed learning rate. -// -// That is for rows we have grad for, we update var as follows: -// prox_v = var - alpha * grad -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// -// Returns the created operation. 
-func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalGradientDescent", - Input: []tf.Input{ - var_, alpha, l1, l2, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // Computes rectified linear gradients for a Relu operation. // // Arguments: @@ -10420,51 +10750,6 @@ func ReciprocalGrad(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Computes the Cholesky decomposition of one or more square matrices. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices, with the same constraints as the single matrix Cholesky -// decomposition above. The output is a tensor of the same shape as the input -// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. -// -// Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[..., M, M]`. -func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cholesky", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Writes contents to the file at input filename. Creates file if not existing. -// -// Arguments: -// filename: scalar. The name of the file to which we write the contents. -// contents: scalar. The content to be written to the output file. -// -// Returns the created operation. 
-func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "WriteFile", - Input: []tf.Input{ - filename, contents, - }, - } - return scope.AddOperation(opspec) -} - // Reverses specific dimensions of a tensor. // // NOTE `tf.reverse` has now changed behavior in preparation for 1.0. @@ -10627,6 +10912,35 @@ func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// Looks up keys in a table, outputs the corresponding values. +// +// The tensor `keys` must of the same type as the keys of the table. +// The output `values` is of the type of the table values. +// +// The scalar `default_value` is the value output for keys not present in the +// table. It must also be of the same type as the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// +// +// Returns Same shape as `keys`. Values found in the table, or `default_values` +// for missing keys. +func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableFindV2", + Input: []tf.Input{ + table_handle, keys, default_value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Given a quantized tensor described by (input, input_min, input_max), outputs a // // range that covers the actual values present in that tensor. This op is @@ -11189,122 +11503,6 @@ func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. -type MutableHashTableOfTensorsV2Attr func(optionalAttr) - -// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. 
-// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// If not specified, defaults to false -func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. -// If not specified, defaults to <> -func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["value_shape"] = value - } -} - -// Creates an empty hash table. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a vector. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. 
-func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MutableHashTableOfTensorsV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. -type ResourceApplyProximalAdagradAttr func(optionalAttr) - -// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. -// -// accum += grad * grad -// prox_v = var - lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. 
-func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyProximalAdagrad", - Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // TensorArrayV3Attr is an optional argument to TensorArrayV3. type TensorArrayV3Attr func(optionalAttr) @@ -11619,54 +11817,6 @@ func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Ou return op.Output(0) } -// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. -type AvgPool3DGradAttr func(optionalAttr) - -// AvgPool3DGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of average pooling function. -// -// Arguments: -// orig_input_shape: The original input dimensions. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. 
-// -// Returns The backprop for input. -func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AvgPool3DGrad", - Input: []tf.Input{ - orig_input_shape, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // QuantizedRelu6Attr is an optional argument to QuantizedRelu6. type QuantizedRelu6Attr func(optionalAttr) @@ -12745,6 +12895,54 @@ func Tanh(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. +type AvgPool3DGradAttr func(optionalAttr) + +// AvgPool3DGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of average pooling function. +// +// Arguments: +// orig_input_shape: The original input dimensions. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. 
+// padding: The type of padding algorithm to use. +// +// Returns The backprop for input. +func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AvgPool3DGrad", + Input: []tf.Input{ + orig_input_shape, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // TextLineReaderV2Attr is an optional argument to TextLineReaderV2. type TextLineReaderV2Attr func(optionalAttr) @@ -13390,39 +13588,6 @@ func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_ return scope.AddOperation(opspec) } -// Shuffle dimensions of x according to a permutation. -// -// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: -// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` -func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Transpose", - Input: []tf.Input{ - x, perm, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reads and outputs the entire contents of the input filename. -func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReadFile", - Input: []tf.Input{ - filename, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Output a fact about factorials. func Fact(scope *Scope) (fact tf.Output) { if scope.Err() != nil { @@ -14260,37 +14425,6 @@ func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...Gathe return op.Output(0) } -// Adjust the contrast of one or more images. 
-// -// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are -// interpreted as `[height, width, channels]`. The other dimensions only -// represent a collection of images, such as `[batch, height, width, channels].` -// -// Contrast is adjusted independently for each channel of each image. -// -// For each channel, the Op first computes the mean of the image pixels in the -// channel and then adjusts each component of each pixel to -// `(x - mean) * contrast_factor + mean`. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// contrast_factor: A float multiplier for adjusting contrast. -// -// Returns The contrast-adjusted image or images. -func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustContrastv2", - Input: []tf.Input{ - images, contrast_factor, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes softsign gradients for a softsign operation. // // Arguments: @@ -14386,31 +14520,6 @@ func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64 return op.Output(0) } -// Decode the first frame of a GIF-encoded image to a uint8 tensor. -// -// GIF with frame or transparency compression are not supported -// convert animated GIF from compressed to uncompressed by: -// -// convert $src.gif -coalesce $dst.gif -// -// Arguments: -// contents: 0-D. The GIF-encoded image. -// -// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order -func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DecodeGif", - Input: []tf.Input{ - contents, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // EncodeBase64Attr is an optional argument to EncodeBase64. 
type EncodeBase64Attr func(optionalAttr) @@ -14672,6 +14781,70 @@ func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment return op.Output(0) } +// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. +type ResizeBilinearGradAttr func(optionalAttr) + +// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, rescale grads by (orig_height - 1) / (height - 1), which +// exactly aligns the 4 corners of grads and original_image. If false, rescale by +// orig_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Computes the gradient of bilinear interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBilinearGrad", + Input: []tf.Input{ + grads, original_image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the number of elements in the given table. +// +// Arguments: +// table_handle: Handle to the table. +// +// Returns Scalar that contains number of elements in the table. 
+func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableSizeV2", + Input: []tf.Input{ + table_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Component-wise divides a SparseTensor by a dense Tensor. // // *Limitation*: this Op only broadcasts the dense side to the sparse side, but not @@ -14727,95 +14900,6 @@ func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value return op.Output(0) } -// ProdAttr is an optional argument to Prod. -type ProdAttr func(optionalAttr) - -// ProdKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func ProdKeepDims(value bool) ProdAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the product of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `reduction_indices`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// reduction_indices: The dimensions to reduce. -// -// Returns The reduced tensor. -func Prod(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...ProdAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Prod", - Input: []tf.Input{ - input, reduction_indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBilinearAttr is an optional argument to ResizeBilinear. -type ResizeBilinearAttr func(optionalAttr) - -// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. 
-// -// value: If true, rescale input by (new_height - 1) / (height - 1), which -// exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Resize `images` to `size` using bilinear interpolation. -// -// Input images can be of different types but output images are always float. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBilinear", - Input: []tf.Input{ - images, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the absolute value of a tensor. // // Given a tensor `x`, this operation returns a tensor containing the absolute @@ -14988,6 +15072,108 @@ func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segm return op.Output(0) } +// Converts one or more images from RGB to HSV. +// +// Outputs a tensor of the same shape as the `images` tensor, containing the HSV +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and +// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 +// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. 
+// +// Arguments: +// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. +// +// Returns `images` converted to HSV. +func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RGBToHSV", + Input: []tf.Input{ + images, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. +type MatrixSolveLsAttr func(optionalAttr) + +// MatrixSolveLsFast sets the optional fast attribute to value. +// If not specified, defaults to true +func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { + return func(m optionalAttr) { + m["fast"] = value + } +} + +// Solves one or more linear least-squares problems. +// +// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`. +// The output is a tensor shape `[..., N, K]` where each output matrix solves +// each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] +// in the least squares sense. +// +// matrix and right-hand sides in the batch: +// +// `matrix`=\\(A \in \Re^{m \times n}\\), +// `rhs`=\\(B \in \Re^{m \times k}\\), +// `output`=\\(X \in \Re^{n \times k}\\), +// `l2_regularizer`=\\(\lambda\\). +// +// If `fast` is `True`, then the solution is computed by solving the normal +// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +// \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares +// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + +// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as +// \\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +// minimum-norm solution to the under-determined linear system, i.e. +// \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to +// \\(A Z = B\\). 
Notice that the fast path is only numerically stable when +// \\(A\\) is numerically full rank and has a condition number +// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is +// sufficiently large. +// +// If `fast` is `False` an algorithm based on the numerically robust complete +// orthogonal decomposition is used. This computes the minimum-norm +// least-squares solution, even when \\(A\\) is rank deficient. This path is +// typically 6-7 times slower than the fast path. If `fast` is `False` then +// `l2_regularizer` is ignored. +// +// Arguments: +// matrix: Shape is `[..., M, N]`. +// rhs: Shape is `[..., M, K]`. +// l2_regularizer: Scalar tensor. +// +// @compatibility(numpy) +// Equivalent to np.linalg.lstsq +// @end_compatibility +// +// Returns Shape is `[..., N, K]`. +func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatrixSolveLs", + Input: []tf.Input{ + matrix, rhs, l2_regularizer, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // QuantizedReluXAttr is an optional argument to QuantizedReluX. type QuantizedReluXAttr func(optionalAttr) @@ -15770,6 +15956,30 @@ func TanhGrad(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } +// Outputs all keys and values in the table. +// +// Arguments: +// table_handle: Handle to the table. +// +// +// +// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. 
+func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} + opspec := tf.OpSpec{ + Type: "LookupTableExportV2", + Input: []tf.Input{ + table_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap. type AddManySparseToTensorsMapAttr func(optionalAttr) @@ -15877,6 +16087,153 @@ func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (o return op.Output(0) } +// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. +type TensorArrayGatherV3Attr func(optionalAttr) + +// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Gather specific elements from the TensorArray into output `value`. +// +// All elements selected by `indices` must have the same shape. +// +// Arguments: +// handle: The handle to a TensorArray. +// indices: The locations in the TensorArray from which to read tensor elements. +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. +// +// Returns All of the elements in the TensorArray, concatenated along a new +// axis (the new dimension 0). 
+func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayGatherV3", + Input: []tf.Input{ + handle, indices, flow_in, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Disallowed in GraphDef version >= 2. +// +// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead +func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustContrast", + Input: []tf.Input{ + images, contrast_factor, min_value, max_value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. +type MaxPoolGradGradAttr func(optionalAttr) + +// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes second-order gradients of the maxpooling function. +// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. 
+// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolGradGrad", + Input: []tf.Input{ + orig_input, orig_output, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 3D real-valued fast Fourier transform. +// +// Computes the 3-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 3 dimensions of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. +// +// Arguments: +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. +// +// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the their 3D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfftn with 3 dimensions. 
+// @end_compatibility +func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RFFT3D", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // UniqueWithCountsAttr is an optional argument to UniqueWithCounts. type UniqueWithCountsAttr func(optionalAttr) @@ -16708,6 +17065,30 @@ func FractionalAvgPool(scope *Scope, value tf.Output, pooling_ratio []float32, o return op.Output(0), op.Output(1), op.Output(2) } +// Updates the table to associates keys with values. +// +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableInsertV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + // Produces the average pool of the input tensor for quantized types. // // Arguments: @@ -16997,41 +17378,6 @@ func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Out return op.Output(0) } -// Draw bounding boxes on a batch of images. -// -// Outputs a copy of `images` but draws on top of the pixels zero or more bounding -// boxes specified by the locations in `boxes`. The coordinates of the each -// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. 
-// -// For example, if an image is 100 x 200 pixels and the bounding box is -// `[0.1, 0.2, 0.5, 0.9]`, the bottom-left and upper-right coordinates of the -// bounding box will be `(10, 40)` to `(50, 180)`. -// -// Parts of the bounding box may fall outside the image. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. -// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding -// boxes. -// -// Returns 4-D with the same shape as `images`. The batch of input images with -// bounding boxes drawn on the images. -func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DrawBoundingBoxes", - Input: []tf.Input{ - images, boxes, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns the element-wise max of two SparseTensors. // // Assumes the two SparseTensors have the same shape, i.e., no broadcasting. @@ -17501,28 +17847,6 @@ func Log(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Computes rectified linear 6 gradients for a Relu6 operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Relu6 operation. -// features: The features passed as input to the corresponding Relu6 operation. -// -// Returns The gradients: -// `gradients * (features > 0) * (features < 6)`. -func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Relu6Grad", - Input: []tf.Input{ - gradients, features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResizeBicubicAttr is an optional argument to ResizeBicubic. 
type ResizeBicubicAttr func(optionalAttr) @@ -17568,6 +17892,28 @@ func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...R return op.Output(0) } +// Computes rectified linear 6 gradients for a Relu6 operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Relu6 operation. +// features: The features passed as input to the corresponding Relu6 operation. +// +// Returns The gradients: +// `gradients * (features > 0) * (features < 6)`. +func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Relu6Grad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes natural logarithm of (1 + x) element-wise. // // I.e., \\(y = \log_e (1 + x)\\). @@ -17681,181 +18027,6 @@ func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Outp return op.Output(0) } -// Gradients for batch normalization. -// -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() -// -// This op is deprecated. See `tf.nn.batch_normalization`. -// -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this Tensor will be multiplied -// with the normalized Tensor. -// backprop: 4D backprop Tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. 
-// -// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. -func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} - opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalizationGrad", - Input: []tf.Input{ - t, m, v, gamma, backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// EncodeJpegAttr is an optional argument to EncodeJpeg. -type EncodeJpegAttr func(optionalAttr) - -// EncodeJpegFormat sets the optional format attribute to value. -// -// value: Per pixel image format. -// If not specified, defaults to "" -func EncodeJpegFormat(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["format"] = value - } -} - -// EncodeJpegQuality sets the optional quality attribute to value. -// -// value: Quality of the compression from 0 to 100 (higher is better and slower). -// If not specified, defaults to 95 -func EncodeJpegQuality(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["quality"] = value - } -} - -// EncodeJpegProgressive sets the optional progressive attribute to value. -// -// value: If True, create a JPEG that loads progressively (coarse to fine). -// If not specified, defaults to false -func EncodeJpegProgressive(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["progressive"] = value - } -} - -// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. 
-// -// value: If True, spend CPU/RAM to reduce size with no quality change. -// If not specified, defaults to false -func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["optimize_size"] = value - } -} - -// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. -// -// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. -// If not specified, defaults to true -func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["chroma_downsampling"] = value - } -} - -// EncodeJpegDensityUnit sets the optional density_unit attribute to value. -// -// value: Unit used to specify `x_density` and `y_density`: -// pixels per inch (`'in'`) or centimeter (`'cm'`). -// If not specified, defaults to "in" -func EncodeJpegDensityUnit(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["density_unit"] = value - } -} - -// EncodeJpegXDensity sets the optional x_density attribute to value. -// -// value: Horizontal pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegXDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["x_density"] = value - } -} - -// EncodeJpegYDensity sets the optional y_density attribute to value. -// -// value: Vertical pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegYDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["y_density"] = value - } -} - -// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. -// -// value: If not empty, embed this XMP metadata in the image header. -// If not specified, defaults to "" -func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["xmp_metadata"] = value - } -} - -// JPEG-encode an image. -// -// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. 
-// -// The attr `format` can be used to override the color format of the encoded -// output. Values can be: -// -// * `''`: Use a default format based on the number of channels in the image. -// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension -// of `image` must be 1. -// * `rgb`: Output an RGB JPEG image. The `channels` dimension -// of `image` must be 3. -// -// If `format` is not specified or is the empty string, a default format is picked -// in function of the number of channels in `image`: -// -// * 1: Output a grayscale image. -// * 3: Output an RGB image. -// -// Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. JPEG-encoded image. -func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodeJpeg", - Input: []tf.Input{ - image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes sin of x element-wise. func Sin(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -18164,6 +18335,117 @@ func ArgMin(scope *Scope, input tf.Output, dimension tf.Output) (output tf.Outpu return op.Output(0) } +// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. +type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) + +// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. 
+// If not specified, defaults to false +func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Sparse update '*var' as FOBOS algorithm with fixed learning rate. +// +// That is for rows we have grad for, we update var as follows: +// prox_v = var - alpha * grad +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyProximalGradientDescent", + Input: []tf.Input{ + var_, alpha, l1, l2, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. +type InitializeTableFromTextFileV2Attr func(optionalAttr) + +// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. +// +// value: Number of elements of the file, use -1 if unknown. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["vocab_size"] = value + } +} + +// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. 
+// +// value: Delimiter to separate fields in a line. +// If not specified, defaults to "\t" +func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["delimiter"] = value + } +} + +// Initializes a table from a text file. +// +// It inserts one key-value pair into the table for each line of the file. +// The key and value is extracted from the whole line content, elements from the +// split line based on `delimiter` or the line number (starting from zero). +// Where to extract the key and value from a line is specified by `key_index` and +// `value_index`. +// +// - A value of -1 means use the line number(starting from zero), expects `int64`. +// - A value of -2 means use the whole line content, expects `string`. +// - A value >= 0 means use the index (starting at zero) of the split line based +// on `delimiter`. +// +// Arguments: +// table_handle: Handle to a table which will be initialized. +// filename: Filename of a vocabulary text file. +// key_index: Column index in a line to get the table `key` values from. +// value_index: Column index that represents information of a line to get the table +// `value` values from. +// +// Returns the created operation. +func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "InitializeTableFromTextFileV2", + Input: []tf.Input{ + table_handle, filename, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // Computes atan of x element-wise. 
func Atan(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -18628,33 +18910,36 @@ func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. -type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) +// BiasAddGradAttr is an optional argument to BiasAddGrad. +type BiasAddGradAttr func(optionalAttr) -// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value. +// BiasAddGradDataFormat sets the optional data_format attribute to value. // -// value: The bitwidth of the quantization; between 2 and 8, inclusive. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the bias tensor will be added to the last dimension +// of the value tensor. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// The tensor will be added to "in_channels", the third-to-the-last +// dimension. +// If not specified, defaults to "NHWC" +func BiasAddGradDataFormat(value string) BiasAddGradAttr { return func(m optionalAttr) { - m["num_bits"] = value + m["data_format"] = value } } -// Compute gradients for a FakeQuantWithMinMaxVars operation. +// The backward operation for "BiasAdd" on the "bias" tensor. +// +// It accumulates all the values from out_backprop into the feature dimension. +// For NHWC data format, the feature dimension is the last. For NCHW data format, +// the feature dimension is the third-to-last. // // Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. -// min, max: Quantization interval, scalar floats. 
+// out_backprop: Any number of dimensions. // -// -// -// Returns Backpropagated gradients w.r.t. inputs: -// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter: -// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter: -// `sum(gradients * (inputs > max))`. -func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { +// Returns 1-D with size the feature dimension of `out_backprop`. +func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -18663,31 +18948,13 @@ func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs t a(attrs) } opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVarsGradient", + Type: "BiasAddGrad", Input: []tf.Input{ - gradients, inputs, min, max, + out_backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Returns the min of x and y (i.e. x < y ? x : y) element-wise. -// -// *NOTE*: `Minimum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Minimum", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) return op.Output(0) } @@ -19996,65 +20263,6 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp return op.Output(0), op.Output(1), op.Output(2) } -// DecodePngAttr is an optional argument to DecodePng. -type DecodePngAttr func(optionalAttr) - -// DecodePngChannels sets the optional channels attribute to value. 
-// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodePngChannels(value int64) DecodePngAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodePngDtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_UINT8 -func DecodePngDtype(value tf.DataType) DecodePngAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Decode a PNG-encoded image to a uint8 or uint16 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the PNG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// * 4: output an RGBA image. -// -// If needed, the PNG-encoded image is transformed to match the requested number -// of color channels. -// -// Arguments: -// contents: 0-D. The PNG-encoded image. -// -// Returns 3-D with shape `[height, width, channels]`. -func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodePng", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // AudioSummaryV2Attr is an optional argument to AudioSummaryV2. type AudioSummaryV2Attr func(optionalAttr) @@ -20219,31 +20427,188 @@ func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate flo return op.Output(0) } -// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. -type ResizeNearestNeighborAttr func(optionalAttr) - -// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. +// Replaces the contents of the table with the specified keys and values. 
// -// value: If true, rescale input by (new_height - 1) / (height - 1), which -// exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableImportV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// HashTableV2Attr is an optional argument to HashTableV2. +type HashTableV2Attr func(optionalAttr) + +// HashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func HashTableV2Container(value string) HashTableV2Attr { return func(m optionalAttr) { - m["align_corners"] = value + m["container"] = value } } -// Resize `images` to `size` using nearest neighbor interpolation. +// HashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func HashTableV2SharedName(value string) HashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. 
+// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates a non-initialized hash table. +// +// This op creates a hash table, specifying the type of its keys and values. +// Before using the table you will have to initialize it. After initialization the +// table will be immutable. // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { +// Returns Handle to a table. +func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "HashTableV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. +type MutableHashTableV2Attr func(optionalAttr) + +// MutableHashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. 
+// If not specified, defaults to "" +func MutableHashTableV2Container(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. 
+type ResourceApplyProximalAdagradAttr func(optionalAttr) + +// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. +// +// accum += grad * grad +// prox_v = var - lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -20252,12 +20617,164 @@ func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optio a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeNearestNeighbor", + Type: "ResourceApplyProximalAdagrad", Input: []tf.Input{ - images, size, + var_, accum, lr, l1, l2, grad, }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. +type MutableHashTableOfTensorsV2Attr func(optionalAttr) + +// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. 
+// If not specified, defaults to "" +func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. +// If not specified, defaults to <> +func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a vector. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. 
+func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableOfTensorsV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Table initializer that takes two tensors for keys and values respectively. +// +// Arguments: +// table_handle: Handle to a table which will be initialized. +// keys: Keys of type Tkey. +// values: Values of type Tval. +// +// Returns the created operation. +func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InitializeTableV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. +type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value. +// +// value: The bitwidth of the quantization; between 2 and 8, inclusive. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxVars operation. +// +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. +// min, max: Quantization interval, scalar floats. +// +// +// +// Returns Backpropagated gradients w.r.t. 
inputs: +// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter: +// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter: +// `sum(gradients * (inputs > max))`. +func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxVarsGradient", + Input: []tf.Input{ + gradients, inputs, min, max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Returns the min of x and y (i.e. x < y ? x : y) element-wise. +// +// *NOTE*: `Minimum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Minimum", + Input: []tf.Input{ + x, y, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } @@ -20385,6 +20902,84 @@ func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_ha return op.Output(0) } +// Adjust the saturation of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. +// +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A scale is then applied all the saturation +// values, and then remapped back to RGB colorspace. +// +// Arguments: +// images: Images to adjust. At least 3-D. +// scale: A float scale to add to the saturation. +// +// Returns The hue-adjusted image or images. 
+func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustSaturation", + Input: []tf.Input{ + images, scale, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. +type SelfAdjointEigV2Attr func(optionalAttr) + +// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. +// +// value: If `True` then eigenvectors will be computed and returned in `v`. +// Otherwise, only the eigenvalues will be computed. +// If not specified, defaults to true +func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { + return func(m optionalAttr) { + m["compute_v"] = value + } +} + +// Computes the eigen decomposition of one or more square self-adjoint matrices. +// +// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in +// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. +// +// ```prettyprint +// # a is a tensor. +// # e is a tensor of eigenvalues. +// # v is a tensor of eigenvectors. +// e, v = self_adjoint_eig(a) +// e = self_adjoint_eig(a, compute_v=False) +// ``` +// +// Arguments: +// input: `Tensor` input of shape `[N, N]`. +// +// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. +func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SelfAdjointEigV2", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // MatrixSolveAttr is an optional argument to MatrixSolve. 
type MatrixSolveAttr func(optionalAttr) @@ -21033,371 +21628,6 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.O return op.Output(0), op.Output(1) } -// Computes the number of elements in the given table. -// -// Arguments: -// table_handle: Handle to the table. -// -// Returns Scalar that contains number of elements in the table. -func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableSizeV2", - Input: []tf.Input{ - table_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. -type ResizeBilinearGradAttr func(optionalAttr) - -// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, rescale grads by (orig_height - 1) / (height - 1), which -// exactly aligns the 4 corners of grads and original_image. If false, rescale by -// orig_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Computes the gradient of bilinear interpolation. -// -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. -// -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. 
-func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBilinearGrad", - Input: []tf.Input{ - grads, original_image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. -type ResizeNearestNeighborGradAttr func(optionalAttr) - -// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, rescale grads by (orig_height - 1) / (height - 1), which -// exactly aligns the 4 corners of grads and original_image. If false, rescale by -// orig_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Computes the gradient of nearest neighbor interpolation. -// -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The -// original input size. -// -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients -// with respect to the input image. -func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeNearestNeighborGrad", - Input: []tf.Input{ - grads, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DecodeJpegAttr is an optional argument to DecodeJpeg. 
-type DecodeJpegAttr func(optionalAttr) - -// DecodeJpegChannels sets the optional channels attribute to value. -// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeJpegChannels(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodeJpegRatio sets the optional ratio attribute to value. -// -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeJpegRatio(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value - } -} - -// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. -// -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value - } -} - -// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. -// -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value - } -} - -// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. -// -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { - return func(m optionalAttr) { - m["acceptable_fraction"] = value - } -} - -// DecodeJpegDctMethod sets the optional dct_method attribute to value. -// -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. 
The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) -// If not specified, defaults to "" -func DecodeJpegDctMethod(value string) DecodeJpegAttr { - return func(m optionalAttr) { - m["dct_method"] = value - } -} - -// Decode a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. -// -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. -// -// Arguments: -// contents: 0-D. The JPEG-encoded image. -// -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeJpeg", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. -type TensorArrayGatherV3Attr func(optionalAttr) - -// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. -// -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. 
-// If not specified, defaults to -func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// Gather specific elements from the TensorArray into output `value`. -// -// All elements selected by `indices` must have the same shape. -// -// Arguments: -// handle: The handle to a TensorArray. -// indices: The locations in the TensorArray from which to read tensor elements. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. -// -// Returns All of the elements in the TensorArray, concatenated along a new -// axis (the new dimension 0). -func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorArrayGatherV3", - Input: []tf.Input{ - handle, indices, flow_in, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. -type MaxPoolGradGradAttr func(optionalAttr) - -// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes second-order gradients of the maxpooling function. 
-// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 3D real-valued fast Fourier transform. -// -// Computes the 3-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 3 dimensions of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the their 3D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. -// -// @compatibility(numpy) -// Equivalent to np.fft.rfftn with 3 dimensions. 
-// @end_compatibility -func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RFFT3D", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deprecated. Disallowed in GraphDef version >= 2. -// -// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead -func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustContrast", - Input: []tf.Input{ - images, contrast_factor, min_value, max_value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Store the input tensor in the state of the current session. // // Arguments: @@ -21419,25 +21649,6 @@ func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { return op.Output(0) } -// Restore a Reader to its initial clean state. -// -// Arguments: -// reader_handle: Handle to a Reader. -// -// Returns the created operation. -func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderResetV2", - Input: []tf.Input{ - reader_handle, - }, - } - return scope.AddOperation(opspec) -} - // Adjust the hue of one or more images. // // `images` is a tensor of at least 3 dimensions. The last dimension is @@ -21466,232 +21677,21 @@ func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Outpu return op.Output(0) } -// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. -type SelfAdjointEigV2Attr func(optionalAttr) - -// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. -// -// value: If `True` then eigenvectors will be computed and returned in `v`. -// Otherwise, only the eigenvalues will be computed. 
-// If not specified, defaults to true -func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { - return func(m optionalAttr) { - m["compute_v"] = value - } -} - -// Computes the eigen decomposition of one or more square self-adjoint matrices. -// -// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in -// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. -// -// ```prettyprint -// # a is a tensor. -// # e is a tensor of eigenvalues. -// # v is a tensor of eigenvectors. -// e, v = self_adjoint_eig(a) -// e = self_adjoint_eig(a, compute_v=False) -// ``` +// Restore a Reader to its initial clean state. // // Arguments: -// input: `Tensor` input of shape `[N, N]`. +// reader_handle: Handle to a Reader. // -// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. -func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SelfAdjointEigV2", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Adjust the saturation of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. -// -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A scale is then applied all the saturation -// values, and then remapped back to RGB colorspace. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// scale: A float scale to add to the saturation. -// -// Returns The hue-adjusted image or images. -func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { +// Returns the created operation. 
+func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AdjustSaturation", + Type: "ReaderResetV2", Input: []tf.Input{ - images, scale, + reader_handle, }, } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// EncodePngAttr is an optional argument to EncodePng. -type EncodePngAttr func(optionalAttr) - -// EncodePngCompression sets the optional compression attribute to value. -// -// value: Compression level. -// If not specified, defaults to -1 -func EncodePngCompression(value int64) EncodePngAttr { - return func(m optionalAttr) { - m["compression"] = value - } -} - -// PNG-encode an image. -// -// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` -// where `channels` is: -// -// * 1: for grayscale. -// * 2: for grayscale + alpha. -// * 3: for RGB. -// * 4: for RGBA. -// -// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder -// default or a value from 0 to 9. 9 is the highest compression level, generating -// the smallest output, but is slower. -// -// Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. PNG-encoded image. -func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodePng", - Input: []tf.Input{ - image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. -type MatrixSolveLsAttr func(optionalAttr) - -// MatrixSolveLsFast sets the optional fast attribute to value. -// If not specified, defaults to true -func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { - return func(m optionalAttr) { - m["fast"] = value - } -} - -// Solves one or more linear least-squares problems. 
-// -// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`. -// The output is a tensor shape `[..., N, K]` where each output matrix solves -// each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] -// in the least squares sense. -// -// matrix and right-hand sides in the batch: -// -// `matrix`=\\(A \in \Re^{m \times n}\\), -// `rhs`=\\(B \in \Re^{m \times k}\\), -// `output`=\\(X \in \Re^{n \times k}\\), -// `l2_regularizer`=\\(\lambda\\). -// -// If `fast` is `True`, then the solution is computed by solving the normal -// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then -// \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares -// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + -// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as -// \\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the -// minimum-norm solution to the under-determined linear system, i.e. -// \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to -// \\(A Z = B\\). Notice that the fast path is only numerically stable when -// \\(A\\) is numerically full rank and has a condition number -// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is -// sufficiently large. -// -// If `fast` is `False` an algorithm based on the numerically robust complete -// orthogonal decomposition is used. This computes the minimum-norm -// least-squares solution, even when \\(A\\) is rank deficient. This path is -// typically 6-7 times slower than the fast path. If `fast` is `False` then -// `l2_regularizer` is ignored. -// -// Arguments: -// matrix: Shape is `[..., M, N]`. -// rhs: Shape is `[..., M, K]`. -// l2_regularizer: Scalar tensor. 
-// -// @compatibility(numpy) -// Equivalent to np.linalg.lstsq -// @end_compatibility -// -// Returns Shape is `[..., N, K]`. -func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixSolveLs", - Input: []tf.Input{ - matrix, rhs, l2_regularizer, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Converts one or more images from RGB to HSV. -// -// Outputs a tensor of the same shape as the `images` tensor, containing the HSV -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. -// -// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and -// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 -// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. -// -// Arguments: -// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. -// -// Returns `images` converted to HSV. 
-func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RGBToHSV", - Input: []tf.Input{ - images, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5e938c73f5a..9fd5ada71ee 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -82,6 +82,7 @@ py_library( "//third_party/py/numpy", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column", + "//tensorflow/python/feature_column:lookup_ops", "//tensorflow/python/ops/losses", "//tensorflow/python/ops/distributions", "//tensorflow/python/saved_model", @@ -1021,7 +1022,6 @@ tf_gen_op_wrapper_private_py( require_shape_functions = True, visibility = [ "//learning/brain/python/ops:__pkg__", - "//tensorflow/contrib/lookup:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", ], ) @@ -1056,6 +1056,16 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "lookup_ops_gen", + require_shape_functions = True, + visibility = [ + "//learning/brain/python/ops:__pkg__", + "//tensorflow/python/feature_column:__pkg__", + "//tensorflow/python/kernel_tests:__pkg__", + ], +) + tf_gen_op_wrapper_private_py( name = "math_ops_gen", require_shape_functions = True, @@ -1473,6 +1483,20 @@ py_library( ], ) +py_library( + name = "lookup_ops", + srcs = ["ops/lookup_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":framework", + ":framework_for_generated_wrappers", + ":lookup_ops_gen", + ":math_ops", + "@six_archive//:six", + ], +) + py_library( name = "math_grad", srcs = ["ops/math_grad.py"], @@ -1861,6 +1885,7 @@ py_library( ":io_ops", ":linalg_ops", ":logging_ops", + ":lookup_ops", ":math_grad", ":math_ops", ":numerics", @@ -2268,6 +2293,7 @@ py_library( ":io_ops", ":io_ops_gen", ":lib", + ":lookup_ops", ":math_ops", ":platform", ":protos_all_py", 
@@ -2990,6 +3016,7 @@ cuda_py_tests( ":framework", ":framework_for_generated_wrappers", ":framework_test_lib", + ":lookup_ops", ":gradients", ":math_ops", ":nn_grad", @@ -3020,7 +3047,7 @@ py_library( srcs = ["training/saver_test_utils.py"], srcs_version = "PY2AND3", deps = [ - ":data_flow_ops_gen", + ":lookup_ops_gen", ":training", ], ) diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 864a96ef348..6336ca23105 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -55,6 +55,7 @@ from tensorflow.core.framework.summary_pb2 import * from tensorflow.core.framework.attr_value_pb2 import * from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo from tensorflow.core.protobuf.config_pb2 import * +from tensorflow.core.protobuf.tensorflow_server_pb2 import * from tensorflow.core.protobuf.rewriter_config_pb2 import * from tensorflow.core.util.event_pb2 import * @@ -131,6 +132,7 @@ _allowed_symbols = [ 'AttrValue', 'AutoParallelOptions', 'ConfigProto', + 'ClusterDef', 'DeviceSpec', 'Event', 'GPUOptions', diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 9add5bd3cde..040cc333158 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -29,6 +29,7 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes @@ -1789,7 +1790,7 @@ class SessionTest(test_util.TensorFlowTestCase): with CaptureStderr() as log: sess.run(c) # Ensure that we did log device placement. 
- self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log)) + self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log), str(log)) def testLocalMasterSessionTimeout(self): # Test that the timeout passed in a config to the session works correctly. @@ -1834,6 +1835,270 @@ class SessionTest(test_util.TensorFlowTestCase): server = server_lib.Server.create_local_server() self.runTestBuildGraphError(session.Session(server.target)) + def testClusterSpecPropagationSimple(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config) + output = sess.run(const) + self.assertEqual(17, output) + + def testClusterSpecPropagationWorker2Placement(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'): + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config, graph=g) + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + output = sess.run(const, options=run_options, run_metadata=run_metadata) + self.assertEqual(17, output) + self.assertEqual(1, + len([ + node_stats + for dev_stats in run_metadata.step_stats.dev_stats + for node_stats in dev_stats.node_stats + if '/job:worker/replica:0/task:1/device:CPU:0' 
== + dev_stats.device and 'Const' == node_stats.node_name + ])) + + def testClusterSpecPropagationWorker1Placement(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'): + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config, graph=g) + output = sess.run(const) + self.assertEqual(17, output) + + def testClusterSpecPropagationThreeServers2Graphs(self): + """Boots 3 servers, creates 2 sessions, ensures appropriate operations. + + We create 2 clusterspecs: + 1. server2 as the master, server1 as a worker + 2. server2 as the master, server3 as a worker + + We ensure that variables on the workers are independent. 
+ """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def1 = cluster_pb2.ClusterDef() + job1 = cluster_def1.job.add() + job1.name = 'worker1' + job1.tasks[0] = server2.target[len('grpc://'):] + job1.tasks[1] = server1.target[len('grpc://'):] + + cluster_def2 = cluster_pb2.ClusterDef() + job2 = cluster_def2.job.add() + job2.name = 'worker2' + job2.tasks[0] = server2.target[len('grpc://'):] + job2.tasks[1] = server3.target[len('grpc://'):] + + config1 = config_pb2.ConfigProto(cluster_def=cluster_def1) + config2 = config_pb2.ConfigProto(cluster_def=cluster_def2) + + with ops.Graph().as_default() as g1: + with ops.device('/job:worker1/task:1'): + var1 = variables.Variable(array_ops.zeros([2]), name='var1') + update_op1 = state_ops.assign_add( + var1, array_ops.ones([2]), name='var1_assign_add') + init1 = variables.global_variables_initializer() + + with ops.Graph().as_default() as g2: + with ops.device('/job:worker2/task:1'): + var2 = variables.Variable(array_ops.zeros([2]), name='var2') + update_op2 = state_ops.assign_add( + var2, array_ops.ones([2]), name='var2_assign_add') + init2 = variables.global_variables_initializer() + + sess1 = session.Session(server2.target, graph=g1, config=config1) + sess2 = session.Session(server2.target, graph=g2, config=config2) + + init1.run(session=sess1) + init2.run(session=sess2) + + expected_zeros = np.zeros([2]) + expected_ones = np.ones([2]) + + self.assertAllEqual(expected_zeros, sess1.run(var1)) + self.assertAllEqual(expected_zeros, sess2.run(var2)) + + self.assertAllEqual(expected_ones, sess1.run(update_op1)) + self.assertAllEqual(expected_ones, sess1.run(var1)) + self.assertAllEqual(expected_zeros, sess2.run(var2)) + self.assertAllEqual(expected_ones, sess2.run(update_op2)) + self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1)) + self.assertAllEqual(expected_ones, sess2.run(var2)) + 
self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1)) + + def testClusterSpecPropagationThreeServers(self): + """Boots 3 servers, creates 2 sessions, ensures appropriate operations. + + We create 2 clusterspecs: + 1. server2 as the master, server1 as a worker + 2. server2 as the master, server3 as a worker + + We ensure that variables on the workers are independent. + """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def1 = cluster_pb2.ClusterDef() + job1 = cluster_def1.job.add() + job1.name = 'worker' + job1.tasks[0] = server2.target[len('grpc://'):] + job1.tasks[1] = server1.target[len('grpc://'):] + + cluster_def2 = cluster_pb2.ClusterDef() + job2 = cluster_def2.job.add() + job2.name = 'worker' + job2.tasks[0] = server2.target[len('grpc://'):] + job2.tasks[1] = server3.target[len('grpc://'):] + + config1 = config_pb2.ConfigProto(cluster_def=cluster_def1) + config2 = config_pb2.ConfigProto(cluster_def=cluster_def2) + + with ops.device('/job:worker/task:1'): + var = variables.Variable(array_ops.zeros([2]), name='var') + feed = array_ops.placeholder(dtypes.float32, shape=(2)) + update_op = var.assign_add(feed) + + sess1 = session.Session(server2.target, config=config1) + sess2 = session.Session(server2.target, config=config2) + + variables.global_variables_initializer().run(session=sess1) + variables.global_variables_initializer().run(session=sess2) + + expected_zeros = np.zeros([2]) + expected_ones = np.ones([2]) + + self.assertAllEqual(expected_zeros, sess1.run(var)) + self.assertAllEqual(expected_zeros, sess2.run(var)) + self.assertAllEqual(expected_ones, + sess1.run(update_op, feed_dict={feed: expected_ones})) + self.assertAllEqual(expected_ones, sess1.run(var)) + self.assertAllEqual(expected_zeros, sess2.run(var)) + self.assertAllEqual(expected_ones, + sess2.run(update_op, feed_dict={feed: expected_ones})) + 
self.assertAllEqual(expected_ones + expected_ones, + sess1.run(update_op, feed_dict={feed: expected_ones})) + self.assertAllEqual(expected_ones, sess2.run(var)) + self.assertAllEqual(expected_ones + expected_ones, sess1.run(var)) + + def testClusterSpecPropagationThreeServersOneCluster(self): + """Boots 3 servers, ensures appropriate communication across workers. + + Additionally, in this cluster, we ensure the master is not the 0-th worker. + + Note: this test only uses one session. + """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server3.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + job.tasks[2] = server1.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + # Add ops to the devices in non-linear order. + + with ops.device('/job:worker/task:1'): + feed1 = array_ops.placeholder(dtypes.float32, shape=(2)) + const1 = constant_op.constant(2.0) + mul1 = const1 * feed1 + + with ops.device('/job:worker/task:2'): + feed2 = array_ops.placeholder(dtypes.float32, shape=(2)) + const2 = constant_op.constant(2.0) + mul2 = const2 * feed2 + + with ops.device('/job:worker/task:0'): + feed0 = array_ops.placeholder(dtypes.float32, shape=(2)) + const0 = constant_op.constant(2.0) + mul0 = const0 * feed0 + + sum_op = mul0 + mul1 + mul2 + + ones = np.ones([2]) + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + + # Run! 
+ with session.Session(server1.target, config=config) as sess: + output = sess.run( + sum_op, + options=run_options, + run_metadata=run_metadata, + feed_dict={feed1: ones, + feed2: ones, + feed0: ones}) + self.assertAllEqual(6 * ones, output) + + self.assertEqual( + 3, + len([ + dev_stats.device + for dev_stats in run_metadata.step_stats.dev_stats + for node_stats in dev_stats.node_stats + if '/job:worker/replica:0/task:' in dev_stats.device and + node_stats.node_name.startswith('Const') + ]), run_metadata) + + def testClusterSpecPropagationPartialRun(self): + """Test successful partial run with ClusterSpec propagation.""" + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.device('/job:worker/task:0'): + a = array_ops.placeholder(dtypes.float32, shape=[]) + with ops.device('/job:worker/task:1'): + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + with ops.device('/job:worker/task:0'): + r2 = math_ops.multiply(r1, c) + + with session.Session(server1.target, config=config) as sess: + h = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + res = sess.partial_run(h, r2, feed_dict={c: 3}) + self.assertEqual(9, res) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 7a12ec01d07..3b7e3b1c904 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -42,8 +42,8 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import 
array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import state_ops @@ -1396,9 +1396,10 @@ class EstimatorExportTest(test.TestCase): my_int = variables.Variable(1, name='my_int', collections=[ops.GraphKeys.LOCAL_VARIABLES]) scores = constant_op.constant([3.]) - with ops.control_dependencies( - [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()]): + with ops.control_dependencies([ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ]): assign_op = state_ops.assign(my_int, 12345) # local_initSop must be an Operation, not a Tensor. diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 37a98cf4815..a1ecd794df6 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -23,6 +23,8 @@ import collections import os import time +import six + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver', if not isinstance(features, dict): features = {_SINGLE_FEATURE_DEFAULT_NAME: features} for name, tensor in features.items(): - if not isinstance(name, str): + if not isinstance(name, six.string_types): raise ValueError('feature keys must be strings: {}.'.format(name)) if not (isinstance(tensor, ops.Tensor) or isinstance(tensor, sparse_tensor.SparseTensor)): @@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver', if not isinstance(receiver_tensors, dict): receiver_tensors = 
{_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} for name, tensor in receiver_tensors.items(): - if not isinstance(name, str): + if not isinstance(name, six.string_types): raise ValueError( 'receiver_tensors keys must be strings: {}.'.format(name)) if not isinstance(tensor, ops.Tensor): diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 69be0f687c1..49bcd06d504 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc +import six + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -171,7 +173,7 @@ class PredictOutput(ExportOutput): 'Prediction outputs must be given as a dict of string to Tensor; ' 'got {}'.format(outputs)) for key, value in outputs.items(): - if not isinstance(key, str): + if not isinstance(key, six.string_types): raise ValueError( 'Prediction output key must be a string; got {}.'.format(key)) if not isinstance(value, ops.Tensor): diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 27a088e551c..035a9a143e6 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.estimator.export import export_output as export_output_lib +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -197,6 +199,33 @@ class 
ExportOutputTest(test.TestCase): signature_constants.CLASSIFY_METHOD_NAME) self.assertEqual(actual_signature_def, expected_signature_def) + def test_predict_output_constructor(self): + """Tests that no errors are raised when input is expected.""" + outputs = { + "output0": constant_op.constant([0]), + u"output1": constant_op.constant([1]), + } + export_output_lib.PredictOutput(outputs) + + def test_predict_output_outputs_invalid(self): + with self.assertRaisesRegexp( + ValueError, + "Prediction outputs must be given as a dict of string to Tensor"): + export_output_lib.PredictOutput(constant_op.constant([0])) + + with self.assertRaisesRegexp( + ValueError, + "Prediction output key must be a string"): + export_output_lib.PredictOutput({1: constant_op.constant([0])}) + + with self.assertRaisesRegexp( + ValueError, + "Prediction output value must be a Tensor"): + export_output_lib.PredictOutput({ + "prediction1": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + }) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index fdd924f2e1c..7946bd88ba0 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2 from tensorflow.python.estimator.export import export from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import constant_op -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from 
tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils class ExportTest(test_util.TensorFlowTestCase): + def test_serving_input_receiver_constructor(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + export.ServingInputReceiver(features, receiver_tensors) + + def test_serving_input_receiver_features_invalid(self): + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + + with self.assertRaisesRegexp(ValueError, "features must be defined"): + export.ServingInputReceiver( + features=None, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "feature keys must be strings"): + export.ServingInputReceiver( + features={1: constant_op.constant([1])}, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature feature1 must be a Tensor or SparseTensor"): + export.ServingInputReceiver( + features={"feature1": [1]}, + receiver_tensors=receiver_tensors) + + def test_serving_input_receiver_receiver_tensors_invalid(self): + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors must be defined"): + export.ServingInputReceiver( + features=features, + receiver_tensors=None) + + with 
self.assertRaisesRegexp( + ValueError, "receiver_tensors keys must be strings"): + export.ServingInputReceiver( + features=features, + receiver_tensors={ + 1: array_ops.placeholder(dtypes.string, name="example0")}) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensor example1 must be a Tensor"): + export.ServingInputReceiver( + features=features, + receiver_tensors={"example1": [1]}) + def test_single_feature_single_receiver(self): feature = constant_op.constant(5) receiver_tensor = array_ops.placeholder(dtypes.string) diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index d5eb20e997c..ac7aef96ac1 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -29,6 +29,7 @@ py_library( srcs = ["feature_column.py"], srcs_version = "PY2AND3", deps = [ + ":lookup_ops", "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", "//tensorflow/python:init_ops", @@ -44,14 +45,47 @@ py_library( ], ) +filegroup( + name = "vocabulary_testdata", + srcs = [ + "testdata/warriors_vocabulary.txt", + "testdata/wire_vocabulary.txt", + ], +) + py_test( name = "feature_column_test", srcs = ["feature_column_test.py"], + data = [":vocabulary_testdata"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ ":feature_column", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", "//tensorflow/python:training", ], ) + +# TODO(ptucker,yleon): Move along with 3p/tf/contrib/lookup. +# Test is still in 3p/tf/contrib/lookup. 
+py_library( + name = "lookup_ops", + srcs = [ + "lookup_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lookup_ops_gen", + "//tensorflow/python:math_ops", + "//tensorflow/python:string_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index a96052a3ae5..ffdf8868e21 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -121,6 +121,9 @@ from __future__ import print_function import abc import collections +import numpy as np + +from tensorflow.python.feature_column import lookup_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib @@ -331,7 +334,9 @@ def numeric_column(key, ``` Args: - key: A string providing key to look up corresponding `Tensor`. + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. shape: An iterable of integers specifies the shape of the `Tensor`. An integer can be given which means a single dimension `Tensor` with given width. The `Tensor` representing the column will have the shape of @@ -430,6 +435,12 @@ def bucketized_column(source_column, boundaries): return _BucketizedColumn(source_column, tuple(boundaries)) +def _assert_string_or_int(dtype, prefix): + if (dtype != dtypes.string) and (not dtype.is_integer): + raise ValueError( + '{} dtype must be string or integer. 
dtype: {}.'.format(prefix, dtype)) + + def categorical_column_with_hash_bucket(key, hash_bucket_size, dtype=dtypes.string): @@ -443,22 +454,22 @@ def categorical_column_with_hash_bucket(key, ```python keywords = categorical_column_with_hash_bucket("keywords", 10K) - all_feature_columns = [keywords, ...] - linear_prediction = make_linear_model(features, all_feature_columns) + linear_prediction = make_linear_model(features, [keywords, ...]) # or keywords_embedded = embedding_column(keywords, 16) - all_feature_columns = [keywords_embedded, ...] - dense_tensor = make_input_layer(features, all_feature_columns) + dense_tensor = make_input_layer(features, [keywords_embedded, ...]) ``` Args: - key: A string providing key to look up corresponding `Tensor`. + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. hash_bucket_size: An int > 1. The number of buckets. dtype: The type of features. Only string and integer types are supported. Returns: - A `_CategoricalColumnHashed`. + A `_HashedCategoricalColumn`. Raises: ValueError: `hash_bucket_size` is not greater than 1. @@ -472,11 +483,177 @@ def categorical_column_with_hash_bucket(key, 'hash_bucket_size: {}, key: {}'.format( hash_bucket_size, key)) - if dtype != dtypes.string and not dtype.is_integer: - raise ValueError('dtype must be string or integer. ' - 'dtype: {}, column_name: {}'.format(dtype, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) - return _CategoricalColumnHashed(key, hash_bucket_size, dtype) + return _HashedCategoricalColumn(key, hash_bucket_size, dtype) + + +def categorical_column_with_vocabulary_file( + key, vocabulary_file, vocabulary_size, num_oov_buckets=0, + default_value=None, dtype=dtypes.string): + """A `_CategoricalColumn` with a vocabulary file. 
+ + Use this when your inputs are in string or integer format, and you have a + vocabulary file that maps each value to an integer ID. By default, + out-of-vocabulary values are ignored. Use either (but not both) of + `num_oov_buckets` and `default_value` to specify how to include + out-of-vocabulary values. + + Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values + can be represented by `-1` for int and `''` for string. Note that these values + are independent of the `default_value` argument. + + Example with `num_oov_buckets`: + File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state + abbreviation. All inputs with values in that file are assigned an ID 0-49, + corresponding to its line number. All other values are hashed and assigned an + ID 50-54. + ```python + states = categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, + num_oov_buckets=5) + linear_prediction = make_linear_model(features, [states, ...]) + ``` + + Example with `default_value`: + File '/us/states.txt' contains 51 lines - the first line is 'XX', and the + other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX' + in input, and other values missing from the file, will be assigned ID 0. All + others are assigned the corresponding line number 1-50. + ```python + states = categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=51, + default_value=0) + linear_prediction, _, _ = make_linear_model(features, [states, ...]) + + And to make an embedding with either: + ```python + dense_tensor = make_input_layer(features, [embedding_column(states, 3),...]) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + vocabulary_file: The vocabulary file name. 
+ vocabulary_size: Number of the elements in the vocabulary. This must be no + greater than length of `vocabulary_file`, if less than length, later + values are ignored. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of + the input value. A positive `num_oov_buckets` can not be specified with + `default_value`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to -1. This can not be specified with a positive + `num_oov_buckets`. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_CategoricalColumn` with a vocabulary file. + + Raises: + ValueError: `vocabulary_file` is missing. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is not a non-negative integer. + ValueError: `dtype` is neither string nor integer. + """ + if not vocabulary_file: + raise ValueError('Missing vocabulary_file in {}.'.format(key)) + # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. + # TODO(ptucker): Should we fail for vocabulary_size==1? 
+ if (vocabulary_size is None) or (vocabulary_size < 1): + raise ValueError('Invalid vocabulary_size in {}.'.format(key)) + if num_oov_buckets: + if default_value is not None: + raise ValueError( + 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( + key)) + if num_oov_buckets < 0: + raise ValueError('Invalid num_oov_buckets {} in {}.'.format( + num_oov_buckets, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + return _VocabularyFileCategoricalColumn( + key=key, + vocabulary_file=vocabulary_file, + vocabulary_size=vocabulary_size, + num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets, + default_value=-1 if default_value is None else default_value, + dtype=dtype) + + +def categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1): + """A `_CategoricalColumn` with in-memory vocabulary. + + Logic for feature f is: + id = f in vocabulary_list ? vocabulary_list.index(f) : default_value + + Use this when your inputs are in string or integer format, and you have an + in-memory vocabulary mapping each value to an integer ID. By default, + out-of-vocabulary values are ignored. Use `default_value` to specify how to + include out-of-vocabulary values. + + Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values + can be represented by `-1` for int and `''` for string. Note that these values + are independent of the `default_value` argument. + + In the following examples, each input in `vocabulary_list` is assigned an ID + 0-4 corresponding to its index (e.g., input 'B' produces output 2). All other + inputs are assigned `default_value` 0. 
+ + Linear model: + ```python + colors = categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0) + linear_prediction, _, _ = make_linear_model(features, [colors, ...]) + ``` + + Embedding for a DNN model: + ```python + dense_tensor = make_input_layer(features, [embedding_column(colors, 3),...]) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. + default_value: The value to use for values not in `vocabulary_list`. + + Returns: + A `_CategoricalColumn` with in-memory vocabulary. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: if `dtype` is not integer or string. 
+ """ + if (vocabulary_list is None) or (len(vocabulary_list) < 1): + raise ValueError( + 'vocabulary_list {} must be non-empty, column_name: {}'.format( + vocabulary_list, key)) + if len(set(vocabulary_list)) != len(vocabulary_list): + raise ValueError( + 'Duplicate keys in vocabulary_list {}, column_name: {}'.format( + vocabulary_list, key)) + vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype) + _assert_string_or_int( + vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key)) + if dtype is None: + dtype = vocabulary_dtype + elif dtype.is_integer != vocabulary_dtype.is_integer: + raise ValueError( + 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( + dtype, vocabulary_dtype, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + + return _VocabularyListCategoricalColumn( + key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype, + default_value=default_value) class _FeatureColumn(object): @@ -764,6 +941,67 @@ class _LazyBuilder(object): return transformed +# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py +def _shape_offsets(shape): + """Returns moving offset for each dimension given shape.""" + offsets = [] + for dim in reversed(shape): + if offsets: + offsets.append(dim * offsets[-1]) + else: + offsets.append(dim) + offsets.reverse() + return offsets + + +# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py +def _to_sparse_input(input_tensor, ignore_value=None): + """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. + + If `input_tensor` is already a `SparseTensor`, just return it. + + Args: + input_tensor: A string or integer `Tensor`. + ignore_value: Entries in `dense_tensor` equal to this value will be + absent from the resulting `SparseTensor`. If `None`, default value of + `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`). + + Returns: + A `SparseTensor` with the same shape as `input_tensor`. 
+ + Raises: + ValueError: when `input_tensor`'s rank is `None`. + """ + input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( + input_tensor) + if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): + return input_tensor + with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)): + input_rank = input_tensor.get_shape().ndims + if input_rank is None: + # TODO(b/32318825): Implement dense_to_sparse_tensor for undefined rank. + raise ValueError('Undefined input_tensor shape.') + if ignore_value is None: + ignore_value = '' if input_tensor.dtype == dtypes.string else -1 + dense_shape = math_ops.cast(array_ops.shape(input_tensor), dtypes.int64) + indices = array_ops.where(math_ops.not_equal( + input_tensor, math_ops.cast(ignore_value, input_tensor.dtype))) + # Flattens the tensor and indices for use with gather. + flat_tensor = array_ops.reshape(input_tensor, [-1]) + flat_indices = indices[:, input_rank - 1] + # Computes the correct flattened indices for 2d (or higher) tensors. + if input_rank > 1: + higher_dims = indices[:, :input_rank - 1] + shape_offsets = array_ops.stack( + _shape_offsets(array_ops.unstack(dense_shape)[1:])) + offsets = math_ops.reduce_sum( + math_ops.multiply(higher_dims, shape_offsets), + reduction_indices=[1]) + flat_indices = math_ops.add(flat_indices, offsets) + values = array_ops.gather(flat_tensor, flat_indices) + return sparse_tensor_lib.SparseTensor(indices, values, dense_shape) + + def _check_feature_columns(feature_columns): if isinstance(feature_columns, dict): raise ValueError('Expected feature_columns to be iterable, found dict.') @@ -951,7 +1189,7 @@ def _check_default_value(shape, default_value, dtype, key): `shape`. dtype: defines the type of values. Default value is `tf.float32`. Must be a non-quantized, real integer or floating point type. - key: A string providing key to look up corresponding `Tensor`. + key: Column name, used only for error messages. 
Returns: A tuple which will be used as default value. @@ -994,9 +1232,9 @@ def _check_default_value(shape, default_value, dtype, key): default_value, dtype, key)) -class _CategoricalColumnHashed( +class _HashedCategoricalColumn( _CategoricalColumn, - collections.namedtuple('_CategoricalColumnHashed', + collections.namedtuple('_HashedCategoricalColumn', ['key', 'hash_bucket_size', 'dtype'])): """see `categorical_column_with_hash_bucket`.""" @@ -1009,15 +1247,13 @@ class _CategoricalColumnHashed( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - input_tensor = inputs.get(self.key) + input_tensor = _to_sparse_input(inputs.get(self.key)) if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') - if (input_tensor.dtype != dtypes.string and - not input_tensor.dtype.is_integer): - raise ValueError('input tensors dtype must be string or integer. ' - 'dtype: {}, column_name: {}'.format( - input_tensor.dtype, self.key)) + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( @@ -1045,6 +1281,109 @@ class _CategoricalColumnHashed( return _CategoricalColumn.IdWeightPair(inputs.get(self), None) +class _VocabularyFileCategoricalColumn( + _CategoricalColumn, + collections.namedtuple('_VocabularyFileCategoricalColumn', ( + 'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype', + 'default_value' + ))): + """See `categorical_column_with_vocabulary_file`.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_config(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input(inputs.get(self.key)) + + if self.dtype.is_integer != input_tensor.dtype.is_integer: + raise ValueError( + 'Column dtype and SparseTensors 
dtype must be compatible. ' + 'key: {}, column dtype: {}, tensor dtype: {}'.format( + self.key, self.dtype, input_tensor.dtype)) + + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + + key_dtype = self.dtype + if input_tensor.dtype.is_integer: + # `index_table_from_file` requires 64-bit integer keys. + key_dtype = dtypes.int64 + input_tensor = math_ops.to_int64(input_tensor) + + return lookup_ops.index_table_from_file( + vocabulary_file=self.vocabulary_file, + num_oov_buckets=self.num_oov_buckets, + vocab_size=self.vocabulary_size, + default_value=self.default_value, + key_dtype=key_dtype, + name='{}_lookup'.format(self.key)).lookup(input_tensor) + + @property + def _num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return self.vocabulary_size + self.num_oov_buckets + + def _get_sparse_tensors( + self, inputs, weight_collections=None, trainable=None): + return _CategoricalColumn.IdWeightPair(inputs.get(self), None) + + +class _VocabularyListCategoricalColumn( + _CategoricalColumn, + collections.namedtuple('_VocabularyListCategoricalColumn', ( + 'key', 'vocabulary_list', 'dtype', 'default_value' + ))): + """See `categorical_column_with_vocabulary_list`.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_config(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input(inputs.get(self.key)) + + if self.dtype.is_integer != input_tensor.dtype.is_integer: + raise ValueError( + 'Column dtype and SparseTensors dtype must be compatible. ' + 'key: {}, column dtype: {}, tensor dtype: {}'.format( + self.key, self.dtype, input_tensor.dtype)) + + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + + key_dtype = self.dtype + if input_tensor.dtype.is_integer: + # `index_table_from_tensor` requires 64-bit integer keys. 
+ key_dtype = dtypes.int64 + input_tensor = math_ops.to_int64(input_tensor) + + return lookup_ops.index_table_from_tensor( + mapping=tuple(self.vocabulary_list), + default_value=self.default_value, + dtype=key_dtype, + name='{}_lookup'.format(self.key)).lookup(input_tensor) + + @property + def _num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return len(self.vocabulary_list) + + def _get_sparse_tensors( + self, inputs, weight_collections=None, trainable=None): + return _CategoricalColumn.IdWeightPair(inputs.get(self), None) + + # TODO(zakaria): Move this to embedding_ops and make it public. def _safe_embedding_lookup_sparse(embedding_weights, sparse_ids, diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index bc626533104..59aa39411f5 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -28,9 +28,10 @@ from tensorflow.python.client import session from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib @@ -40,7 +41,7 @@ from tensorflow.python.platform import test def _initialized_session(): sess = session.Session() sess.run(variables_lib.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) return sess @@ -552,7 +553,7 @@ class BucketizedColumnTest(test.TestCase): self.assertAllClose([[81.], [141.]], 
predictions.eval()) -class SparseColumnHashedTest(test.TestCase): +class HashedCategoricalColumnTest(test.TestCase): def test_defaults(self): a = fc.categorical_column_with_hash_bucket('aaa', 10) @@ -578,11 +579,14 @@ class SparseColumnHashedTest(test.TestCase): def test_deep_copy(self): """Tests deepcopy of categorical_column_with_hash_bucket.""" - column = fc.categorical_column_with_hash_bucket('aaa', 10) - column_copy = copy.deepcopy(column) - self.assertEqual('aaa', column_copy.name) - self.assertEqual(10, column_copy.hash_bucket_size) - self.assertEqual(dtypes.string, column_copy.dtype) + original = fc.categorical_column_with_hash_bucket('aaa', 10) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + self.assertEqual(10, column.hash_bucket_size) + # pylint: disable=protected-access + self.assertEqual(10, column._num_buckets) + # pylint: enable=protected-access + self.assertEqual(dtypes.string, column.dtype) def test_parse_config(self): a = fc.categorical_column_with_hash_bucket('aaa', 10) @@ -681,14 +685,45 @@ class SparseColumnHashedTest(test.TestCase): def test_get_sparse_tensors(self): hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) - wire_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - builder = fc._LazyBuilder({'wire': wire_tensor}) - self.assertEqual( - builder.get(hashed_sparse), - hashed_sparse._get_sparse_tensors(builder).id_tensor) + builder = fc._LazyBuilder({ + 'wire': sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + }) + id_weight_pair = hashed_sparse._get_sparse_tensors(builder) + self.assertIsNone(id_weight_pair.weight_tensor) + self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor) + + def test_get_sparse_tensors_dense_input(self): + hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) + 
builder = fc._LazyBuilder({ + 'wire': (('omar', ''), ('stringer', 'marlo')) + }) + id_weight_pair = hashed_sparse._get_sparse_tensors(builder) + self.assertIsNone(id_weight_pair.weight_tensor) + self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor) + + def test_make_linear_model(self): + wire_column = fc.categorical_column_with_hash_bucket('wire', 4) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 3: wire_var[3] = 4 + # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6 + self.assertAllClose(((4.,), (6.,)), predictions.eval()) def get_linear_model_bias(): @@ -1158,5 +1193,640 @@ class MakeInputLayerTest(test.TestCase): self.assertAllClose([[1., 3.]], net2.eval()) +def _assert_sparse_tensor_value(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class VocabularyFileCategoricalColumnTest(test.TestCase): + + def setUp(self): + super(VocabularyFileCategoricalColumnTest, self).setUp() + + # Contains ints, Golden State Warriors jersey numbers: 30, 35, 
11, 23, 22 + self._warriors_vocabulary_file_name = test.test_src_dir_path( + 'python/feature_column/testdata/warriors_vocabulary.txt') + self._warriors_vocabulary_size = 5 + + # Contains strings, character names from 'The Wire': omar, stringer, marlo + self._wire_vocabulary_file_name = test.test_src_dir_path( + 'python/feature_column/testdata/wire_vocabulary.txt') + self._wire_vocabulary_size = 3 + + def test_defaults(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.string) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_all_constructor_args(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3, + num_oov_buckets=4, dtype=dtypes.int32) + # pylint: disable=protected-access + self.assertEqual(7, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_deep_copy(self): + """Tests deepcopy of categorical_column_with_vocabulary_file.""" + original = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3, + num_oov_buckets=4, dtype=dtypes.int32) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(7, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_vocabulary_file_none(self): + with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=None,
vocabulary_size=3) + + def test_vocabulary_file_empty_string(self): + with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='', vocabulary_size=3) + + def test_invalid_vocabulary_file(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'): + with self.test_session(): + lookup_ops.tables_initializer().run() + + def test_invalid_vocabulary_size(self): + with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=None) + with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=-1) + with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=0) + + def test_too_large_vocabulary_size(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size + 1) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + with 
self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'): + with self.test_session(): + lookup_ops.tables_initializer().run() + + def test_invalid_num_oov_buckets(self): + with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path', vocabulary_size=3, + num_oov_buckets=-1) + + def test_invalid_dtype(self): + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path', vocabulary_size=3, + dtype=dtypes.float64) + + def test_invalid_buckets_and_default_value(self): + with self.assertRaisesRegexp( + ValueError, 'both num_oov_buckets and default_value'): + fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=100, + default_value=2) + + def test_invalid_input_dtype_int32(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + dtype=dtypes.string) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(12, 24, 36), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_invalid_input_dtype_string(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: 
disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': (('marlo', ''), ('skywalker', 'omar')) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_default_value_in_vocabulary(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + default_value=2) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), 
+ dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 2, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_with_oov_buckets(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=100) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (1, 2)), + values=('marlo', 'skywalker', 'omar', 'heisenberg'), + dense_shape=(2, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 33, 0, 62), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_small_vocabulary_size(self): + # 'marlo' is the last entry in our vocabulary file, so by setting + # `vocabulary_size` to 1 less than number of entries in file, we take + # 'marlo' out of the vocabulary.
+ column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size - 1) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((-1, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=(11, 100, 30, 22), + dense_shape=(3, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_dense_input(self): + default_value = -100 + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32, + default_value=default_value) + # pylint: disable=protected-access + id_weight_pair = 
column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22)) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((2, default_value, 0, 4), dtype=np.int64), + dense_shape=(3, 3)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_with_oov_buckets(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32, + num_oov_buckets=100) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=(11, 100, 30, 22), + dense_shape=(3, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 60, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_make_linear_model(self): + wire_column = fc.categorical_column_with_vocabulary_file( + key='wire', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=1) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + 
self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 + self.assertAllClose(((3.,), (5.,)), predictions.eval()) + + +class VocabularyListCategoricalColumnTest(test.TestCase): + + def test_defaults_string(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.string) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_defaults_int(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36)) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_all_constructor_args(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32, + default_value=-99) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_deep_copy(self): + """Tests deepcopy of categorical_column_with_vocabulary_list.""" + original = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access +
self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_invalid_dtype(self): + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'), + dtype=dtypes.float32) + + def test_invalid_mapping_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary dtype must be string or integer'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12., 24., 36.)) + + def test_mismatched_int_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'dtype.*and vocabulary dtype.*do not match'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'), + dtype=dtypes.int32) + + def test_mismatched_string_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'dtype.*and vocabulary dtype.*do not match'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.string) + + def test_none_mapping(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary_list.*must be non-empty'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=None) + + def test_empty_mapping(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary_list.*must be non-empty'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=tuple([])) + + def test_duplicate_mapping(self): + with self.assertRaisesRegexp(ValueError, 'Duplicate keys'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 12)) + + def test_invalid_input_dtype_int32(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), 
+ values=(12, 24, 36), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_invalid_input_dtype_string(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=(12, 24, 36)) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': (('marlo', ''), ('skywalker', 'omar')) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + 
sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_default_value_in_vocabulary(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo'), + default_value=2) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 2, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((11, 100, 30, 22), dtype=np.int32), + dense_shape=(3, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_dense_input(self): + default_value = -100 + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), 
dtype=np.int32), + dtype=dtypes.int32, + default_value=default_value) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': np.array( + ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), + dtype=np.int32) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((2, default_value, 0, 4), dtype=np.int64), + dense_shape=(3, 3)), + id_weight_pair.id_tensor.eval()) + + def test_make_linear_model(self): + wire_column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + self.assertEqual(3, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> None, 'omar' -> 0: wire_var[0] = 1 + self.assertAllClose(((3.,), (1.,)), predictions.eval()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/python/feature_column/lookup_ops.py similarity index 96% rename from tensorflow/contrib/lookup/lookup_ops.py rename to tensorflow/python/feature_column/lookup_ops.py index 9dc7414cd07..8225b47b204 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/python/feature_column/lookup_ops.py @@ -12,8 +12,8 @@ # See the License for the 
specific language governing permissions and # limitations under the License. # ============================================================================== -"""Lookup table Operations.""" -# pylint: disable=g-bad-name +"""Lookup table operations.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -27,7 +27,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops from tensorflow.python.training.saver import BaseSaverBuilder @@ -151,7 +151,7 @@ class InitializableLookupTableBase(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as scope: # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=scope) # pylint: enable=protected-access def lookup(self, keys, name=None): @@ -182,7 +182,7 @@ class InitializableLookupTableBase(LookupInterface): name, "%s_Lookup" % self._name, (self._table_ref, key_tensor, self._default_value)) as scope: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find( + values = gen_lookup_ops._lookup_table_find( self._table_ref, key_tensor, self._default_value, name=scope) # pylint: enable=protected-access @@ -229,7 +229,7 @@ class HashTable(InitializableLookupTableBase): with ops.name_scope( name, "hash_table", (initializer, default_value)) as scope: # pylint: disable=protected-access - table_ref = gen_data_flow_ops._hash_table( + table_ref = gen_lookup_ops._hash_table( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, @@ 
-308,10 +308,8 @@ class KeyValueTensorInitializer(TableInitializerBase): self._name, values=(table.table_ref, self._keys, self._values)) as scope: # pylint: disable=protected-access - init_op = gen_data_flow_ops._initialize_table(table.table_ref, - self._keys, - self._values, - name=scope) + init_op = gen_lookup_ops._initialize_table( + table.table_ref, self._keys, self._values, name=scope) # pylint: enable=protected-access ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op @@ -477,7 +475,7 @@ class TextFileInitializer(TableInitializerBase): dtypes.string, name="asset_filepath") # pylint: disable=protected-access - init_op = gen_data_flow_ops._initialize_table_from_text_file( + init_op = gen_lookup_ops._initialize_table_from_text_file( table.table_ref, filename, self._key_index, @@ -608,7 +606,7 @@ class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])): __slots__ = () -FastHashSpec = HasherSpec("fasthash", None) +FastHashSpec = HasherSpec("fasthash", None) # pylint: disable=invalid-name class StrongHashSpec(HasherSpec): @@ -1333,14 +1331,14 @@ class MutableHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None # pylint: disable=protected-access if self._default_value.get_shape().ndims == 0: - self._table_ref = gen_data_flow_ops._mutable_hash_table( + self._table_ref = gen_lookup_ops._mutable_hash_table( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, value_dtype=value_dtype, name=name) else: - self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors( + self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, @@ -1368,7 +1366,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - return 
gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -1394,10 +1392,8 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, (self._table_ref, keys, self._default_value)) as name: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find(self._table_ref, - keys, - self._default_value, - name=name) + values = gen_lookup_ops._lookup_table_find( + self._table_ref, keys, self._default_value, name=name) values.set_shape(keys.get_shape().concatenate(self._value_shape)) return values @@ -1423,7 +1419,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: # pylint: disable=protected-access - op = gen_data_flow_ops._lookup_table_insert( + op = gen_lookup_ops._lookup_table_insert( self._table_ref, keys, values, name=name) return op @@ -1440,11 +1436,8 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - exported_keys, exported_values = gen_data_flow_ops._lookup_table_export( - self._table_ref, - self._key_dtype, - self._value_dtype, - name=name) + exported_keys, exported_values = gen_lookup_ops._lookup_table_export( + self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( self._value_shape)) @@ -1464,7 +1457,7 @@ class MutableHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_import( + return gen_lookup_ops._lookup_table_import( self.op._table_ref, restored_tensors[0], restored_tensors[1]) @@ -1539,7 +1532,7 @@ 
class MutableDenseHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype) # pylint: disable=protected-access - self._table_ref = gen_data_flow_ops._mutable_dense_hash_table( + self._table_ref = gen_lookup_ops._mutable_dense_hash_table( empty_key=empty_key, shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, @@ -1567,7 +1560,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -1593,7 +1586,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, [self._table_ref, keys]) as name: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find( + values = gen_lookup_ops._lookup_table_find( self._table_ref, keys, self._default_value, name=name) if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0: @@ -1623,7 +1616,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: # pylint: disable=protected-access - op = gen_data_flow_ops._lookup_table_insert( + op = gen_lookup_ops._lookup_table_insert( self._table_ref, keys, values, name=name) return op @@ -1640,7 +1633,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - exported_keys, exported_values = gen_data_flow_ops._lookup_table_export( + exported_keys, exported_values = gen_lookup_ops._lookup_table_export( self._table_ref, 
self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( @@ -1661,6 +1654,5 @@ class MutableDenseHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_import(self.op._table_ref, - restored_tensors[0], - restored_tensors[1]) + return gen_lookup_ops._lookup_table_import( + self.op._table_ref, restored_tensors[0], restored_tensors[1]) diff --git a/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt b/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt new file mode 100644 index 00000000000..6c917fa6999 --- /dev/null +++ b/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt @@ -0,0 +1,5 @@ +30 +35 +11 +23 +22 diff --git a/tensorflow/python/feature_column/testdata/wire_vocabulary.txt b/tensorflow/python/feature_column/testdata/wire_vocabulary.txt new file mode 100644 index 00000000000..32c6b5692a0 --- /dev/null +++ b/tensorflow/python/feature_column/testdata/wire_vocabulary.txt @@ -0,0 +1,3 @@ +omar +stringer +marlo diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 452cf3be703..0b04904ec23 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -70,25 +70,33 @@ def _override_helper(clazz_object, operator, func): setattr(clazz_object, operator, func) -def _convert_stack(stack): +def _convert_stack(stack, include_func_start_lineno=False): """Converts a stack extracted using _extract_stack() to a traceback stack. Args: - stack: A list of n 4-tuples, (filename, lineno, name, frame_globals). + stack: A list of n 5-tuples, + (filename, lineno, name, frame_globals, func_start_lineno). + include_func_start_lineno: True if function start line number should be + included as the 5th entry in return tuples. 
Returns: - A list of n 4-tuples (filename, lineno, name, code), where the code tuple - element is calculated from the corresponding elements of the input tuple. + A list of n 4-tuples or 5-tuples + (filename, lineno, name, code, [optional: func_start_lineno]), where the + code tuple element is calculated from the corresponding elements of the + input tuple. """ ret = [] - for filename, lineno, name, frame_globals in stack: + for filename, lineno, name, frame_globals, func_start_lineno in stack: linecache.checkcache(filename) line = linecache.getline(filename, lineno, frame_globals) if line: line = line.strip() else: line = None - ret.append((filename, lineno, name, line)) + if include_func_start_lineno: + ret.append((filename, lineno, name, line, func_start_lineno)) + else: + ret.append((filename, lineno, name, line)) return ret @@ -103,7 +111,8 @@ def _extract_stack(): be formatted etc. using traceback methods. Returns: - A list of 4-tuples (filename, lineno, name, frame_globals) corresponding to + A list of 5-tuples + (filename, lineno, name, frame_globals, func_start_lineno) corresponding to the call stack of the current thread. """ # pylint: enable=line-too-long @@ -118,7 +127,8 @@ def _extract_stack(): filename = co.co_filename name = co.co_name frame_globals = f.f_globals - ret.append((filename, lineno, name, frame_globals)) + func_start_lineno = co.co_firstlineno + ret.append((filename, lineno, name, frame_globals, func_start_lineno)) f = f.f_back ret.reverse() return ret @@ -1505,6 +1515,15 @@ class Operation(object): """Returns the call stack from when this operation was constructed.""" return _convert_stack(self._traceback) + @property + def traceback_with_start_lines(self): + """Same as traceback but includes start line of function definition. + + Returns: + A list of 5-tuples (filename, lineno, name, code, func_start_lineno). 
+ """ + return _convert_stack(self._traceback, include_func_start_lineno=True) + def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 06d03121a0f..3e9f047a7de 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -22,6 +22,7 @@ import gc import weakref from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op @@ -1703,5 +1704,26 @@ class NameScopeTest(test_util.TensorFlowTestCase): self.assertEqual("", g.get_name_scope()) +class TracebackTest(test_util.TensorFlowTestCase): + + def testTracebackWithStartLines(self): + with self.test_session() as sess: + a = constant_op.constant(2.0) + sess.run( + a, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue(sess.graph.get_operations()) + + # Tests that traceback_with_start_lines is the same as traceback + # but includes one more element at the end. 
+ for op in sess.graph.get_operations(): + self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines)) + for frame, frame_with_start_line in zip( + op.traceback, op.traceback_with_start_lines): + self.assertEquals(5, len(frame_with_start_line)) + self.assertEquals(frame, frame_with_start_line[:-1]) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index 3630adc9549..50a07952004 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -249,6 +249,23 @@ cuda_py_test( ], ) +cuda_py_test( + name = "identity_bijector_test", + size = "small", + srcs = ["identity_bijector_test.py"], + additional_deps = [ + "//tensorflow/python/ops/distributions", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py similarity index 84% rename from tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py rename to tensorflow/python/kernel_tests/distributions/identity_bijector_test.py index 0969c293d40..e8f9d0b728d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py +++ b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency 
-from tensorflow.contrib.distributions.python.ops.bijectors.identity import Identity +from tensorflow.python.ops.distributions import bijector_test_util +from tensorflow.python.ops.distributions import identity_bijector from tensorflow.python.platform import test @@ -28,7 +28,7 @@ class IdentityBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - bijector = Identity() + bijector = identity_bijector.Identity() self.assertEqual("identity", bijector.name) x = [[[0.], [1.]]] self.assertAllEqual(x, bijector.forward(x).eval()) @@ -38,8 +38,8 @@ class IdentityBijectorTest(test.TestCase): def testScalarCongruency(self): with self.test_session(): - bijector = Identity() - assert_scalar_congruency( + bijector = identity_bijector.Identity() + bijector_test_util.assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py index 80991751860..a0bd178e247 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -161,6 +162,46 @@ class SparseTensorDenseMatMulTest(test.TestCase): sparse_ops.sparse_tensor_dense_matmul( sparse_t, dense_t, adjoint_a=True).eval() + def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self): + # Note: use_gpu=True here; nice errors are only returned from CPU kernels. + if not test.is_gpu_available(): + return + with self.test_session(use_gpu=True): + indices = np.array([[1, 
10]]).astype(np.int64) + values = np.array([10]).astype(np.float32) + shape = [3, 2] + sparse_t = sparse_tensor.SparseTensor(indices, values, shape) + + # Test multiplying by both a small and large dense matrix, to hit + # both cases in the kernel. + dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32) + expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32) + self.assertAllClose(expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t).eval()) + dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32) + expected_t = np.array( + [[0] * 500, [np.nan] * 500, [0] * 500], dtype=np.float32) + self.assertAllClose(expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t).eval()) + + # Repeat with adjoint_a, now the error is that the sparse index + # is OOO w.r.t. the output. The GPU kernel can't do much here, + # so it just doesn't accumulate. + + dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32) + expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32) + self.assertAllClose(expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t, adjoint_a=True).eval()) + + dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32) + expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32) + self.assertAllClose(expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t, adjoint_a=True).eval()) + # Tests setting one dimension to be a high value. 
def _testLarge(self, np_dtype): r1 = np.random.randint(6000, 20000) @@ -175,9 +216,12 @@ class SparseTensorDenseMatMulTest(test.TestCase): y = _maybe_complex(np.random.randn(k, n).astype(np_dtype)) - self._testMatmul(x, y) + self._testMatmul(x, y, adjoint_a=False, adjoint_b=False) + self._testMatmul(x.transpose(), y, adjoint_a=True, adjoint_b=False) + self._testMatmul(x, y.transpose(), adjoint_a=False, adjoint_b=True) + self._testMatmul( + x.transpose(), y.transpose(), adjoint_a=True, adjoint_b=True) - def testLarge(self): np.random.seed(127) # Repeatable results self._testLarge(np.float32) self._testLarge(np.float64) @@ -221,7 +265,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x, y, adjoint_a, lambda t, _: t < iterations, body, (t0, v0), parallel_iterations=1, - back_prop=False) + back_prop=False, + shape_invariants=(tensor_shape.TensorShape(()), + tensor_shape.TensorShape(None))) return [final] return _timeit @@ -246,7 +292,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(x_ind, x_val, x_shape, lambda t, _: t < iterations, body, (t0, v0), parallel_iterations=1, - back_prop=False) + back_prop=False, + shape_invariants=(tensor_shape.TensorShape(()), + tensor_shape.TensorShape(None))) return [final] return _timeit @@ -291,7 +339,7 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh, if skip_dense: delta_dense = float("nan") else: - with session.Session("", config=config, graph=ops.Graph()) as sess: + with session.Session(config=config, graph=ops.Graph()) as sess: if not use_gpu: with ops.device("/cpu:0"): x_t = constant_op.constant(x) @@ -299,12 +347,12 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh, ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense( x_t, y_t, adjoint_a, adjoint_b) else: - x_t = constant_op.constant(x) - y_t = constant_op.constant(y) - ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x_t, y_t, - adjoint_a, - adjoint_b) - delta_dense = _timer(sess, ops_fn, 1000) + with 
ops.device("/gpu:0"): + x_t = constant_op.constant(x) + y_t = constant_op.constant(y) + ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense( + x_t, y_t, adjoint_a, adjoint_b) + delta_dense = _timer(sess, ops_fn, 200) # Using sparse_tensor_dense_matmul. with session.Session("", config=config, graph=ops.Graph()) as sess: @@ -317,13 +365,14 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh, ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse( x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b) else: - x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T) - x_val = constant_op.constant(x[np.where(x)]) - x_shape = constant_op.constant(np.array(x.shape).astype(np.int64)) - y_t = constant_op.constant(y) - ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse( - x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b) - delta_sparse = _timer(sess, ops_fn, 1000) + with ops.device("/gpu:0"): + x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T) + x_val = constant_op.constant(x[np.where(x)]) + x_shape = constant_op.constant(np.array(x.shape).astype(np.int64)) + y_t = constant_op.constant(y) + ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse( + x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b) + delta_sparse = _timer(sess, ops_fn, 200) print("%g \t %d \t %s \t %d \t %d \t %g \t %g \t %g" % (1 - thresh, n, use_gpu, m, k, delta_dense, delta_sparse, @@ -340,7 +389,7 @@ def main(_): "\t dt(sparse)/dt(dense)") for thresh in (0.99, 0.8, 0.5, 0.2): - for n in (1, 10, 25): + for n in (50, 100): for use_gpu in (True, False): for m in (100, 1000): for k in (100, 1000): diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 95e803e2aa0..9a208613add 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -38,7 +38,6 @@ from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import from 
tensorflow.python.ops.gen_data_flow_ops import * # pylint: enable=wildcard-import -from tensorflow.python.util.deprecation import deprecated def _as_type_list(dtypes): @@ -1037,47 +1036,6 @@ class Barrier(object): self._barrier_ref, name=name) -@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.") -def initialize_all_tables(name="init_all_tables"): - """Returns an Op that initializes all tables of the default graph. - - Args: - name: Optional name for the initialization op. - - Returns: - An Op that initializes all tables. Note that if there are - not tables the returned Op is a NoOp. - """ - return tables_initializer(name) - - -def tables_initializer(name="init_all_tables"): - """Returns an Op that initializes all tables of the default graph. - - Args: - name: Optional name for the initialization op. - - Returns: - An Op that initializes all tables. Note that if there are - not tables the returned Op is a NoOp. - """ - initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS) - if initializers: - return control_flow_ops.group(*initializers, name=name) - return control_flow_ops.no_op(name=name) - - -ops.NotDifferentiable("LookupTableFind") -ops.NotDifferentiable("LookupTableInsert") -ops.NotDifferentiable("LookupTableSize") -ops.NotDifferentiable("HashTable") -ops.NotDifferentiable("InitializeTable") -ops.NotDifferentiable("InitializeTableFromTextFile") -ops.NotDifferentiable("MutableDenseHashTable") -ops.NotDifferentiable("MutableHashTable") -ops.NotDifferentiable("MutableHashTableOfTensors") - - class ConditionalAccumulatorBase(object): """A conditional accumulator for aggregating gradients. 
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/bijector_test_util.py b/tensorflow/python/ops/distributions/bijector_test_util.py similarity index 100% rename from tensorflow/contrib/distributions/python/ops/bijectors/bijector_test_util.py rename to tensorflow/python/ops/distributions/bijector_test_util.py diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/identity_impl.py b/tensorflow/python/ops/distributions/identity_bijector.py similarity index 100% rename from tensorflow/contrib/distributions/python/ops/bijectors/identity_impl.py rename to tensorflow/python/ops/distributions/identity_bijector.py diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py similarity index 99% rename from tensorflow/contrib/distributions/python/ops/transformed_distribution.py rename to tensorflow/python/ops/distributions/transformed_distribution.py index e146e20d3ac..09b26a9fb73 100644 --- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py +++ b/tensorflow/python/ops/distributions/transformed_distribution.py @@ -21,7 +21,6 @@ import numpy as np # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. 
-from tensorflow.contrib.distributions.python.ops.bijectors.identity import Identity from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,6 +31,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import identity_bijector from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -265,7 +265,7 @@ class TransformedDistribution(distribution_lib.Distribution): self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty") if bijector is None: - bijector = Identity(validate_args=validate_args) + bijector = identity_bijector.Identity(validate_args=validate_args) # We will keep track of a static and dynamic version of # self._is_{batch,event}_override. This way we can do more prior to graph diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py new file mode 100644 index 00000000000..54dba9e38eb --- /dev/null +++ b/tensorflow/python/ops/lookup_ops.py @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#============================================================================== +"""Lookup operations.""" +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_lookup_ops import * +# pylint: enable=wildcard-import +from tensorflow.python.util.deprecation import deprecated + + +@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.") +def initialize_all_tables(name="init_all_tables"): + """Returns an Op that initializes all tables of the default graph. + + Args: + name: Optional name for the initialization op. + + Returns: + An Op that initializes all tables. Note that if there are + no tables the returned Op is a NoOp. + """ + return tables_initializer(name) + + +def tables_initializer(name="init_all_tables"): + """Returns an Op that initializes all tables of the default graph. + + Args: + name: Optional name for the initialization op. + + Returns: + An Op that initializes all tables. Note that if there are + no tables the returned Op is a NoOp. 
+ """ + initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS) + if initializers: + return control_flow_ops.group(*initializers, name=name) + return control_flow_ops.no_op(name=name) + + +ops.NotDifferentiable("LookupTableFind") +ops.NotDifferentiable("LookupTableFindV2") +ops.NotDifferentiable("LookupTableInsert") +ops.NotDifferentiable("LookupTableInsertV2") +ops.NotDifferentiable("LookupTableSize") +ops.NotDifferentiable("LookupTableSizeV2") +ops.NotDifferentiable("HashTable") +ops.NotDifferentiable("HashTableV2") +ops.NotDifferentiable("InitializeTable") +ops.NotDifferentiable("InitializeTableV2") +ops.NotDifferentiable("InitializeTableFromTextFile") +ops.NotDifferentiable("InitializeTableFromTextFileV2") +ops.NotDifferentiable("MutableDenseHashTable") +ops.NotDifferentiable("MutableDenseHashTableV2") +ops.NotDifferentiable("MutableHashTable") +ops.NotDifferentiable("MutableHashTableV2") +ops.NotDifferentiable("MutableHashTableOfTensors") +ops.NotDifferentiable("MutableHashTableOfTensorsV2") diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 4dc8e702ca3..28ed3af9d73 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -1924,7 +1924,74 @@ def recall_at_k(labels, labels = _maybe_expand_labels(labels, predictions) _, top_k_idx = nn.top_k(predictions, k) - top_k_idx = math_ops.to_int64(top_k_idx) + return _sparse_recall_at_top_k( + labels=labels, + predictions_idx=top_k_idx, + k=k, + class_id=class_id, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=scope) + + +def _sparse_recall_at_top_k(labels, + predictions_idx, + k=None, + class_id=None, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes recall@k of top-k predictions with respect to sparse labels. 
+ + Differs from `recall_at_k` in that predictions must be in the form of top `k` + class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k` + for more details. + + Args: + labels: `int64` `Tensor` or `SparseTensor` with shape + [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies + num_labels=1. N >= 1 and num_labels is the number of target classes for + the associated prediction. Commonly, N=1 and `labels` has shape + [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values + should be in range [0, num_classes), where num_classes is the last + dimension of `predictions`. Values outside this range always count + towards `false_negative_at_<k>`. + predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. + Commonly, N=1 and predictions has shape [batch size, k]. The final + dimension contains the top `k` predicted class indices. [D1, ... DN] must + match `labels`. + k: Integer, k for @k metric. + class_id: Integer class ID for which we want binary metrics. This should be + in range [0, num_classes), where num_classes is the last dimension of + `predictions`. If class_id is outside this range, the method returns NAN. + weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of + `labels`. If the latter, it must be broadcastable to `labels` (i.e., all + dimensions must be either `1`, or the same as the corresponding `labels` + dimension). + metrics_collections: An optional list of collections that values should + be added to. + updates_collections: An optional list of collections that updates should + be added to. + name: Name of new update operation, and namespace for other dependent ops. + + Returns: + recall: Scalar `float64` `Tensor` with the value of `true_positives` divided + by the sum of `true_positives` and `false_negatives`. + update_op: `Operation` that increments `true_positives` and + `false_negatives` variables appropriately, and whose value matches + `recall`. 
+ + Raises: + ValueError: If `weights` is not `None` and its shape doesn't match + `predictions`, or if either `metrics_collections` or `updates_collections` + are not a list or tuple. + """ + with ops.name_scope(name, + _at_k_name('recall', k, class_id=class_id), + (predictions_idx, labels, weights)) as scope: + top_k_idx = math_ops.to_int64(predictions_idx) tp, tp_update = _streaming_sparse_true_positive_at_k( predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id, weights=weights) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 4810e97b367..c7ac742b5d9 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -28,6 +28,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as tf_variables from tensorflow.python.util import nest @@ -75,11 +77,13 @@ def _zero_state_tensors(state_size, batch_size, dtype): return zeros -class _RNNCell(base_layer.Layer): # pylint: disable=protected-access +class _RNNCell(base_layer.Layer): """Abstract object representing an RNN cell. - Every `RNNCell` must have the properties below and implement `__call__` with - the following signature. + Every `RNNCell` must have the properties below and implement `call` with + the signature `(output, next_state) = call(input, state)`. The optional + third input argument, `scope`, is allowed for backwards compatibility + purposes; but should be left off for new subclasses. This definition of cell differs from the definition used in the literature. In the literature, 'cell' refers to an object with a single scalar output. 
@@ -90,8 +94,9 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access This operation results in an output matrix with `self.output_size` columns. If `self.state_size` is an integer, this operation also results in a new state matrix with `self.state_size` columns. If `self.state_size` is a - tuple of integers, then it results in a tuple of `len(state_size)` state - matrices, each with a column size corresponding to values in `state_size`. + (possibly nested tuple of) TensorShape object(s), then it should return a + matching structure of Tensors having shape `[batch_size].concatenate(s)` + for each `s` in `self.state_size`. """ def __call__(self, inputs, state, scope=None): @@ -112,7 +117,25 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access - New state: Either a single `2-D` tensor, or a tuple of tensors matching the arity and shapes of `state`. """ - return super(_RNNCell, self).__call__(inputs, state, scope=scope) + if scope is not None: + with vs.variable_scope(scope, + custom_getter=self._rnn_get_variable) as scope: + return super(_RNNCell, self).__call__(inputs, state, scope=scope) + else: + with vs.variable_scope(vs.get_variable_scope(), + custom_getter=self._rnn_get_variable): + return super(_RNNCell, self).__call__(inputs, state) + + def _rnn_get_variable(self, getter, *args, **kwargs): + variable = getter(*args, **kwargs) + trainable = (variable in tf_variables.trainable_variables() or + (isinstance(variable, tf_variables.PartitionedVariable) and + list(variable)[0] in tf_variables.trainable_variables())) + if trainable and variable not in self._trainable_weights: + self._trainable_weights.append(variable) + elif not trainable and variable not in self._non_trainable_weights: + self._non_trainable_weights.append(variable) + return variable @property def state_size(self): @@ -128,6 +151,11 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access """Integer or TensorShape: size of outputs produced by this 
cell.""" raise NotImplementedError("Abstract method") + def build(self, _): + # This tells the parent Layer object that it's OK to call + # self.add_variable() inside the call() method. + pass + def zero_state(self, batch_size, dtype): """Return zero-filled state tensor(s). diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 0140a27aaa7..d6cb7c5be49 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -241,6 +241,8 @@ def sparse_add(a, b, thresh=0): of arguments does not matter. Use vanilla `tf.add()` for adding two dense `Tensor`s. + The shapes of the two operands must match: broadcasting is not supported. + The indices of any input `SparseTensor` are assumed ordered in standard lexicographic order. If this is not the case, before this step run `SparseReorder` to restore index ordering. diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 09e04d4247c..a39d28490cc 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -57,6 +57,7 @@ from tensorflow.python.ops.io_ops import * from tensorflow.python.ops.linalg_ops import * from tensorflow.python.ops.logging_ops import Print from tensorflow.python.ops.logging_ops import get_summary_op +from tensorflow.python.ops.lookup_ops import * from tensorflow.python.ops.math_ops import * from tensorflow.python.ops.numerics import * from tensorflow.python.ops.parsing_ops import * diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py index 66cf9d4d8af..355fd57bf1d 100644 --- a/tensorflow/python/saved_model/main_op_impl.py +++ b/tensorflow/python/saved_model/main_op_impl.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops +from 
tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables @@ -35,7 +35,7 @@ def main_op(): """ init = variables.global_variables_initializer() init_local = variables.local_variables_initializer() - init_tables = tf_data_flow_ops.tables_initializer() + init_tables = lookup_ops.tables_initializer() return control_flow_ops.group(init, init_local, init_tables) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 2fea29d961e..c9c56a50143 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -14,68 +14,8 @@ # ============================================================================== """Command-line interface to inspect and execute a graph in a SavedModel. -If TensorFlow is installed on your system through pip, the 'saved_model_cli' -binary can be invoked directly from command line. - -At a high level, SavedModel CLI allows users to both inspect and execute -computations on a MetaGraphDef in a SavedModel. These are done through `show` -and `run` commands. Following is the usage of the two commands. SavedModel -CLI will also display these information with -h option. 
- -'show' command usage: saved_model_cli show [-h] --dir DIR [--tag_set TAG_SET] - [--signature_def SIGNATURE_DEF_KEY] -Examples: -To show all available tag-sets in the SavedModel: - $saved_model_cli show --dir /tmp/saved_model - -To show all available SignatureDef keys in a MetaGraphDef specified by its -tag-set: - $saved_model_cli show --dir /tmp/saved_model --tag_set serve -For a MetaGraphDef with multiple tags in the tag-set, all tags must be passed -in, separated by ',': - $saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu - -To show all inputs and outputs TensorInfo for a specific SignatureDef specified -by the SignatureDef key in a MetaGraphDef: - $saved_model_cli show --dir /tmp/saved_model --tag_set serve - --signature_def serving_default -Example output: - The given SavedModel SignatureDef contains the following input(s): - inputs['input0'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - inputs['input1'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - The given SavedModel SignatureDef contains the following output(s): - outputs['output'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - Method name is: tensorflow/serving/regress - -To show all available information in the SavedModel: - $saved_model_cli show --dir /tmp/saved_model --all - -usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def - SIGNATURE_DEF_KEY [--inputs INPUTS] - [--input_exprs INPUT_EXPRS] [--outdir OUTDIR] - [--overwrite] [--tf_debug] - -Examples: -To run input tensors from files through a MetaGraphDef and save the output -tensors to files: - $saved_model_cli run --dir /tmp/saved_model --tag_set serve - --signature_def serving_default --inputs x=/tmp/124.npz - --input_exprs 'x2=np.ones((6,2))' --outdir /tmp/out - -To observe the intermediate Tensor values in the runtime graph, use the ---tf_debug flag, e.g.: - $saved_model_cli run --dir /tmp/saved_model --tag_set serve - --signature_def serving_default --inputs 'x=/tmp/124.npz;x2=/tmp/123.npy' - 
--outdir /tmp/out --tf_debug - -To build this tool from source, run: - $bazel build tensorflow/python/tools:saved_model_cli +For detailed usages and examples, please refer to: +https://www.tensorflow.org/programmers_guide/saved_model_cli """ diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index c9056723130..a891bae5f23 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging @@ -238,7 +238,7 @@ class Scaffold(object): @staticmethod def _default_local_init_op(): return control_flow_ops.group(variables.local_variables_initializer(), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) def MonitoredTrainingSession(master='', # pylint: disable=invalid-name diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py index 5f31e2aa539..6a73565f82b 100644 --- a/tensorflow/python/training/saver_test_utils.py +++ b/tensorflow/python/training/saver_test_utils.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib -from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.training import saver as saver_module @@ -34,7 +34,7 @@ class CheckpointedOp(object): # pylint: disable=protected-access def __init__(self, name, table_ref=None): if table_ref is None: - self.table_ref = 
gen_data_flow_ops._mutable_hash_table( + self.table_ref = gen_lookup_ops._mutable_hash_table( key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name) else: self.table_ref = table_ref @@ -52,10 +52,10 @@ class CheckpointedOp(object): return self._saveable def insert(self, keys, values): - return gen_data_flow_ops._lookup_table_insert(self.table_ref, keys, values) + return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values) def lookup(self, keys, default): - return gen_data_flow_ops._lookup_table_find(self.table_ref, keys, default) + return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default) def keys(self): return self._export()[0] @@ -64,8 +64,8 @@ class CheckpointedOp(object): return self._export()[1] def _export(self): - return gen_data_flow_ops._lookup_table_export(self.table_ref, dtypes.string, - dtypes.float32) + return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string, + dtypes.float32) class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject): """A custom saveable for CheckpointedOp.""" @@ -81,6 +81,6 @@ class CheckpointedOp(object): super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name) def restore(self, restore_tensors, shapes): - return gen_data_flow_ops._lookup_table_import( + return gen_lookup_ops._lookup_table_import( self.op.table_ref, restore_tensors[0], restore_tensors[1]) # pylint: enable=protected-access diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py index d2ccf37d885..2091eca0b9c 100644 --- a/tensorflow/python/training/server_lib.py +++ b/tensorflow/python/training/server_lib.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import errors @@ -276,14 +277,14 
@@ class ClusterSpec(object): "from integers to strings." % job_name) self._cluster_spec[job_name] = job_tasks self._make_cluster_def() - elif isinstance(cluster, tensorflow_server_pb2.ClusterDef): + elif isinstance(cluster, cluster_pb2.ClusterDef): self._cluster_def = cluster self._cluster_spec = {} for job_def in self._cluster_def.job: self._cluster_spec[job_def.name] = { i: t for i, t in job_def.tasks.items()} elif isinstance(cluster, ClusterSpec): - self._cluster_def = tensorflow_server_pb2.ClusterDef() + self._cluster_def = cluster_pb2.ClusterDef() self._cluster_def.MergeFrom(cluster.as_cluster_def()) self._cluster_spec = {} for job_def in self._cluster_def.job: @@ -440,7 +441,7 @@ class ClusterSpec(object): TypeError: If `cluster_spec` is not a dictionary mapping strings to lists of strings. """ - self._cluster_def = tensorflow_server_pb2.ClusterDef() + self._cluster_def = cluster_pb2.ClusterDef() # NOTE(mrry): Sort by job_name to produce deterministic protobufs. for job_name, tasks in sorted(self._cluster_spec.items()): diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 277c11386dd..230ed1db687 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as _summary @@ -426,8 +426,10 @@ class Supervisor(object): local_init_op = self._get_first_op_from_collection( ops.GraphKeys.LOCAL_INIT_OP) if local_init_op is None: - op_list = [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()] + op_list = 
[ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ] if op_list: local_init_op = control_flow_ops.group(*op_list) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index bdf3d9c0175..f4ac3c97587 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -186,8 +186,8 @@ from tensorflow.python.training.learning_rate_decay import * # pylint: enable=wildcard-import # Distributed computing support. -from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef -from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef +from tensorflow.core.protobuf.cluster_pb2 import JobDef from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef from tensorflow.python.training.server_lib import ClusterSpec from tensorflow.python.training.server_lib import Server @@ -196,32 +196,32 @@ from tensorflow.python.training.server_lib import Server _allowed_symbols = [ # TODO(cwhipkey): review these and move to contrib or expose through # documentation. - "generate_checkpoint_state_proto", # Used internally by saver. + "generate_checkpoint_state_proto", # Used internally by saver. "checkpoint_exists", # Only used in test? "get_checkpoint_mtimes", # Only used in test? # Legacy: remove. "do_quantize_training_on_graphdef", # At least use grah_def, not graphdef. - # No uses within tensorflow. + # No uses within tensorflow. "queue_runner", # Use tf.train.start_queue_runner etc directly. - # This is also imported internally. + # This is also imported internally. # TODO(drpng): document these. The reference in howtos/distributed does # not link. "SyncReplicasOptimizer", # Protobufs: - "BytesList", # from example_pb2. + "BytesList", # from example_pb2. 
"ClusterDef", - "Example", # from example_pb2 - "Feature", # from example_pb2 - "Features", # from example_pb2 - "FeatureList", # from example_pb2 - "FeatureLists", # from example_pb2 - "FloatList", # from example_pb2. - "Int64List", # from example_pb2. + "Example", # from example_pb2 + "Feature", # from example_pb2 + "Features", # from example_pb2 + "FeatureList", # from example_pb2 + "FeatureLists", # from example_pb2 + "FloatList", # from example_pb2. + "Int64List", # from example_pb2. "JobDef", - "SaverDef", # From saver_pb2. - "SequenceExample", # from example_pb2. + "SaverDef", # From saver_pb2. + "SequenceExample", # from example_pb2. "ServerDef", ] # Include extra modules for docstrings because: diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json index 69f08495a30..d424f103dd7 100644 --- a/tensorflow/tensorboard/package.json +++ b/tensorflow/tensorboard/package.json @@ -30,7 +30,7 @@ "merge2": "~0.3.6", "minimist": "~1.2.0", "tsify": "^0.14.8", - "typescript": "2.2.2", + "typescript": "2.3.1", "typings": "1.4.0", "vinyl-source-stream": "^1.1.0", "vulcanize": "^1.14.0", diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt index 805a9bdd4f1..da6af3919e9 100644 --- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt @@ -6,6 +6,10 @@ tf_class { name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER" mtype: "" } + member { + name: "CLUSTER_DEF_FIELD_NUMBER" + mtype: "" + } member { name: "DESCRIPTOR" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt b/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt index 0f43a49ee96..64240f70698 100644 --- a/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt @@ -38,6 +38,10 @@ tf_class { name: "traceback" mtype: "" } + member { + name: 
"traceback_with_start_lines" + mtype: "" + } member { name: "type" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt index feb73bd7d4f..93ff856b09d 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.ClusterDef" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt index 2d7fcbe5456..ac6d81541a4 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.JobDef.TasksEntry" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt index fc5b76341d2..ce34537fa13 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.JobDef" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR" diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index 61c3fe55405..459d6ee3284 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -28,11 +28,13 @@ import subprocess PIP_PACKAGE_QUERY = """bazel query \ 'deps(//tensorflow/tools/pip_package:build_pip_package)'""" -PY_TEST_QUERY = """bazel query 'filter("^((?!(benchmark|manual|no_pip)).)*$", \ - deps(kind(py_test,\ - 
//tensorflow/python/... + \ - //tensorflow/tensorboard/... + \ - //tensorflow/contrib/...), 1))'""" +PY_TEST_QUERY = """bazel query 'deps(\ + filter("^((?!benchmark).)*$",\ + kind(py_test,\ + //tensorflow/python/... \ + + //tensorflow/tensorboard/... \ + + //tensorflow/contrib/... \ + - attr(tags, "manual|no_pip", //tensorflow/...))), 1)'""" # Hard-coded blacklist of files if not included in pip package # TODO(amitpatankar): Clean up blacklist. @@ -45,6 +47,7 @@ BLACKLIST = [ "//tensorflow/python:compare_test_proto_py", "//tensorflow/core:image_testdata", "//tensorflow/core/kernels/cloud:bigquery_reader_ops", + "//tensorflow/python/feature_column:vocabulary_testdata", "//tensorflow/python:framework/test_file_system.so", # contrib "//tensorflow/contrib/session_bundle:session_bundle_half_plus_two", @@ -54,7 +57,7 @@ BLACKLIST = [ "//tensorflow/contrib/factorization/examples:mnist.py", "//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", # pylint:disable=line-too-long "//tensorflow/contrib/bayesflow:reinforce_simple_example", - "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py" # pylint:disable=line-too-long + "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long ] @@ -121,7 +124,10 @@ def main(): affected_tests_list = affected_tests.split("\n")[:-2] print("\n".join(affected_tests_list)) - raise RuntimeError("One or more dependencies are not in the pip package.") + raise RuntimeError("""One or more dependencies are not in the pip package. 
+Please either blacklist the dependencies in +tensorflow/tensorflow/tensorflow/tools/pip_package/pip_smoke_test.py +or add them to tensorflow/tensorflow/tensorflow/tools/pip_package/BUILD.""") else: print("TEST PASSED") diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 6270b95b6b3..3831a481bad 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -687,13 +687,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_microsoft_typescript", licenses = ["notice"], # Apache 2.0 sha256_urls = { - "43a7c763fe024d5add8d5365e5a7981f4a359ba5bf86481f545a0db8f60d48cc": [ - "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/tsc.js", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/tsc.js", + "8465342c318f9c4cf0a29b109fa63ee3742dd4dc7080d05d9fd8f604814d04cf": [ + "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", + "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", ], - "aecec1e47a3b3d872e214cb9adb82b30d6bd0471ea0aad7311ad81428566627c": [ - "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/lib.es6.d.ts", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/lib.es6.d.ts", + "a67e36da3029d232e4e938e61a0a3302f516d71e7100d54dbf5362ad8618e994": [ + "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", + "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", ], }, extra_build_file_content = "\n".join([