Merge commit for internal changes
This commit is contained in:
commit
19ff8745b6
15
configure
vendored
15
configure
vendored
@ -49,6 +49,15 @@ while true; do
|
||||
# Retry
|
||||
done
|
||||
|
||||
## Set up architecture-dependent optimization flags.
|
||||
if [ -z "$CC_OPT_FLAGS" ]; then
|
||||
default_cc_opt_flags="-march=native"
|
||||
read -p "Please specify optimization flags to use during compilation [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS
|
||||
if [ -z "$CC_OPT_FLAGS" ]; then
|
||||
CC_OPT_FLAGS=$default_cc_opt_flags
|
||||
fi
|
||||
fi
|
||||
|
||||
if is_windows; then
|
||||
TF_NEED_GCP=0
|
||||
TF_NEED_HDFS=0
|
||||
@ -148,6 +157,12 @@ fi
|
||||
# Invoke python_config and set up symlinks to python includes
|
||||
./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
|
||||
|
||||
# Append CC optimization flags to bazel.rc
|
||||
echo >> tools/bazel.rc
|
||||
for opt in $CC_OPT_FLAGS; do
|
||||
echo "build:opt --cxxopt=$opt --copt=$opt" >> tools/bazel.rc
|
||||
done
|
||||
|
||||
# Run the gen_git_source to create links where bazel can track dependencies for
|
||||
# git hash propagation
|
||||
GEN_GIT_SOURCE=tensorflow/tools/git/gen_git_source.py
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {}
|
||||
|
||||
@ -110,5 +109,4 @@ Input::Initializer::Initializer(
|
||||
tensor = t;
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
class Output;
|
||||
|
||||
@ -193,6 +192,7 @@ class Input {
|
||||
// * A scalar, or a multi-dimensional tensor specified as a recursive
|
||||
// initializer list. This enables directly passing constants as
|
||||
// inputs to op wrappers.
|
||||
// * A Tensor object.
|
||||
Input(const Output& o) : output_(o) {} // NOLINT(runtime/explicit)
|
||||
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
@ -249,7 +249,7 @@ typedef std::vector<Output> OutputList;
|
||||
class InputList {
|
||||
public:
|
||||
// Implicitly convert a list of outputs to a list of inputs. This is useful to
|
||||
// write code such as tf.Concat(tf.Split(x, 4)).
|
||||
// write code such as ops::Concat(ops::Split(x, 4)).
|
||||
InputList(const OutputList& out) { // NOLINT(runtime/explicit)
|
||||
for (auto const& x : out) {
|
||||
inputs_.push_back(x);
|
||||
@ -284,7 +284,19 @@ class InputList {
|
||||
std::vector<Input> inputs_;
|
||||
};
|
||||
|
||||
// These symbols used to live in the ops namespace, so we temporarily
|
||||
// declare some aliases there. TODO(josh11b): Delete this!
|
||||
namespace ops {
|
||||
|
||||
using ::tensorflow::Input;
|
||||
using ::tensorflow::InputList;
|
||||
using ::tensorflow::Operation;
|
||||
using ::tensorflow::Output;
|
||||
using ::tensorflow::OutputHash;
|
||||
using ::tensorflow::OutputList;
|
||||
|
||||
} // namespace ops
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
|
||||
|
@ -602,12 +602,10 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation,
|
||||
std::vector<int>* output_permutation, NodeDef* node) {
|
||||
// Optimize the subgraph.
|
||||
Graph* g = subgraph->release();
|
||||
OptimizeGraph(flr.get(), &g);
|
||||
subgraph->reset(g);
|
||||
OptimizeGraph(flr.get(), subgraph);
|
||||
|
||||
std::vector<bool> const_args(input_permutation->size());
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*g, &const_args));
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
|
||||
|
||||
// Compute a permutation of the arguments such that the constant arguments
|
||||
// are first.
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
// Run tests, comparing the Tensorflow CPU operators with their XLA-compiled
|
||||
// counterparts:
|
||||
// randomized_tests \
|
||||
// --tf_xla_test_use_jit=true --tf_xla_test_device=CPU \
|
||||
// --tf_xla_test_use_jit=true --tf_xla_test_device=CPU:0 \
|
||||
// --tf_xla_test_repetitions=20
|
||||
|
||||
// TODO(phawkins): add tests for:
|
||||
@ -50,6 +50,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -66,6 +67,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -76,9 +78,8 @@ int32 tf_xla_test_repetitions = 20;
|
||||
string* tf_xla_test_device_ptr; // initial value set in main()
|
||||
bool tf_xla_test_use_jit = true;
|
||||
|
||||
string DeviceTypeToDeviceName(DeviceType type) {
|
||||
return strings::StrCat("/job:localhost/replica:0/task:0/device:", type.type(),
|
||||
":0");
|
||||
string LocalDeviceToFullDeviceName(const string& device) {
|
||||
return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
|
||||
}
|
||||
|
||||
constexpr std::array<DataType, 3> kAllXlaTypes = {
|
||||
@ -575,9 +576,14 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
|
||||
|
||||
void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
|
||||
double atol, double rtol) {
|
||||
string cpu_device = DeviceTypeToDeviceName(DEVICE_CPU);
|
||||
DeviceType test_device_type(*tf_xla_test_device_ptr);
|
||||
string test_device = DeviceTypeToDeviceName(test_device_type);
|
||||
string cpu_device =
|
||||
LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0"));
|
||||
string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);
|
||||
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
ASSERT_TRUE(
|
||||
DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name));
|
||||
DeviceType test_device_type(parsed_name.type);
|
||||
++num_tests_;
|
||||
|
||||
GraphDef graph;
|
||||
@ -2058,7 +2064,7 @@ TEST_F(OpTest, ZerosLike) {
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU");
|
||||
tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0");
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
tensorflow::Flag(
|
||||
"tf_xla_random_seed", &tensorflow::tf_xla_random_seed,
|
||||
@ -2085,13 +2091,18 @@ int main(int argc, char** argv) {
|
||||
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
|
||||
return 2;
|
||||
}
|
||||
// XLA devices register kernels at construction time; create and destroy all
|
||||
// known devices to make sure the kernels are registered.
|
||||
// XLA devices register kernels at construction time; create all known devices
|
||||
// to make sure the kernels are registered.
|
||||
std::vector<tensorflow::Device*> devices;
|
||||
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
|
||||
tensorflow::SessionOptions(), "", &devices));
|
||||
for (tensorflow::Device* device : devices) {
|
||||
delete device;
|
||||
}
|
||||
tensorflow::DeviceMgr device_mgr(devices);
|
||||
|
||||
tensorflow::Device* ignored;
|
||||
TF_QCHECK_OK(
|
||||
device_mgr.LookupDevice(*tensorflow::tf_xla_test_device_ptr, &ignored))
|
||||
<< "Unknown test device (" << *tensorflow::tf_xla_test_device_ptr
|
||||
<< "). Did you build in the right configuration (e.g., is CUDA enabled)?";
|
||||
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -97,7 +97,6 @@ cc_test(
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
|
@ -186,9 +186,7 @@ Status XlaCompiler::CompileFunctionBody(
|
||||
// for devices other than CPU.
|
||||
OptimizerOptions opts;
|
||||
GraphOptimizer optimizer(opts);
|
||||
Graph* g = graph.release();
|
||||
OptimizeGraph(flr, &g);
|
||||
graph.reset(g);
|
||||
OptimizeGraph(flr, &graph);
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
dump_graph::DumpGraphToFile(
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
|
@ -51,18 +51,9 @@ bool IsLiteralWithValue(const HloInstruction* operand, int value) {
|
||||
|
||||
// Returns whether the given transpose produces a result which is bit-wise
|
||||
// identical to its operand and thus may be replaced with a bitcast.
|
||||
bool TransposeIsBitcast(
|
||||
const HloInstruction* transpose,
|
||||
const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
|
||||
bool TransposeIsBitcast(const HloInstruction* transpose) {
|
||||
CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
|
||||
const HloInstruction* operand = transpose->operand(0);
|
||||
|
||||
// Can't insert bitcasts if the compiler used a memory layout which isn't
|
||||
// compatible.
|
||||
if (!valid_bitcast_callback(operand->shape(), transpose->shape())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
|
||||
transpose->dimensions());
|
||||
}
|
||||
@ -80,11 +71,8 @@ bool ReshapeIsBitcast(
|
||||
const HloInstruction* operand = reshape->operand(0);
|
||||
// Can't insert bitcasts if the compiler used a memory layout which isn't
|
||||
// compatible.
|
||||
if (!valid_bitcast_callback(operand->shape(), reshape->shape())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape());
|
||||
return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) &&
|
||||
valid_bitcast_callback(operand->shape(), reshape->shape());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -199,7 +187,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
// Whether layout is considered during transformation.
|
||||
bool is_layout_sensitive_;
|
||||
|
||||
// Callback used to determine if a bitcast is valid.
|
||||
// Callback used to determine if a bitcast is possible.
|
||||
AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_;
|
||||
};
|
||||
|
||||
@ -287,7 +275,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
|
||||
HloInstruction* rhs) {
|
||||
// A/1 => A
|
||||
VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
|
||||
if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) {
|
||||
if (IsLiteralWithValue(rhs, 1) &&
|
||||
ReplaceInstructionIfSameShape(divide, lhs)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -717,8 +706,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (is_layout_sensitive_ &&
|
||||
TransposeIsBitcast(transpose, valid_bitcast_callback_)) {
|
||||
if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) {
|
||||
ReplaceWithBitcast(transpose);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -26,9 +26,11 @@ namespace xla {
|
||||
// A pass which performs AlgebraicSimplications.
|
||||
class AlgebraicSimplifier : public HloPass {
|
||||
public:
|
||||
// Given two shapes, determines if it is valid to bitcast between them.
|
||||
// Precondition: the two shapes have layouts and have the same number of
|
||||
// elements.
|
||||
// Given two shapes, determines if it is valid to bitcast between them after
|
||||
// considering platform dependent effects on layout like alignment
|
||||
// restrictions.
|
||||
// Precondition: the two shapes have layouts, the same number of
|
||||
// elements and ShapeUtil::ReshapeIsBitcast returns true.
|
||||
using ValidBitcastCallback = std::function<bool(const Shape&, const Shape&)>;
|
||||
|
||||
// If is_layout_sensitive is true, then the simplifier preserves layout during
|
||||
|
@ -495,11 +495,13 @@ string DumpGraph(const HloComputation& computation, const string& label,
|
||||
}
|
||||
|
||||
void DumpText(const HloModule& module, const string& label,
|
||||
const string& directory_path) {
|
||||
const string& directory_path, bool do_prefix) {
|
||||
Env* env = Env::Default();
|
||||
TF_CHECK_OK(env->RecursivelyCreateDir(directory_path));
|
||||
string prefix = StrCat(env->NowMicros());
|
||||
string path = JoinPath(directory_path, StrCat(prefix, "-", label, ".txt"));
|
||||
string filename =
|
||||
do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt");
|
||||
string path = JoinPath(directory_path, filename);
|
||||
TF_CHECK_OK(WriteStringToFile(env, path, module.ToString()));
|
||||
}
|
||||
|
||||
|
@ -33,8 +33,12 @@ string DumpGraph(const HloComputation& computation, const string& label,
|
||||
|
||||
// Dumps the HloModule::ToString() as a file into the provided directory path
|
||||
// suffixed with the provided label.
|
||||
//
|
||||
// If do_prefix is true, a timestamp will be prepended onto the label to
|
||||
// construct a filename in the directory path; otherwise, the label is used
|
||||
// as the filename directly.
|
||||
void DumpText(const HloModule& module, const string& label,
|
||||
const string& directory_path);
|
||||
const string& directory_path, bool do_prefix = true);
|
||||
|
||||
// Abstract interface for classes that render DOT graphs.
|
||||
class GraphRendererInterface {
|
||||
|
@ -27,6 +27,7 @@ option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
|
||||
option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON)
|
||||
option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF)
|
||||
option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF)
|
||||
option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
|
||||
|
||||
if (NOT WIN32)
|
||||
# Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option
|
||||
@ -67,7 +68,15 @@ if(WIN32)
|
||||
endif()
|
||||
|
||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||
set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} "-fno-exceptions -std=c++11")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -std=c++11")
|
||||
endif()
|
||||
|
||||
if (tensorflow_OPTIMIZE_FOR_NATIVE_ARCH)
|
||||
include(CheckCXXCompilerFlag)
|
||||
CHECK_CXX_COMPILER_FLAG("-march=native" COMPILER_OPT_ARCH_NATIVE_SUPPORTED)
|
||||
if (COMPILER_OPT_ARCH_NATIVE_SUPPORTED)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# External dependencies
|
||||
|
@ -448,6 +448,23 @@ cuda_py_tests(
|
||||
tags = ["nomsan"], # disable to avoid false positives from scipy.
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "vector_student_t_test",
|
||||
size = "medium",
|
||||
srcs = ["python/kernel_tests/vector_student_t_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
":distributions_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "uniform_test",
|
||||
size = "small",
|
||||
|
@ -93,6 +93,11 @@ representing the posterior or posterior predictive.
|
||||
|
||||
@@kl
|
||||
@@RegisterKL
|
||||
|
||||
## Utilities
|
||||
|
||||
@@softplus_inverse
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -110,6 +115,7 @@ from tensorflow.contrib.distributions.python.ops.dirichlet import *
|
||||
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
|
||||
from tensorflow.contrib.distributions.python.ops.distribution import *
|
||||
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
|
||||
from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
|
||||
from tensorflow.contrib.distributions.python.ops.exponential import *
|
||||
from tensorflow.contrib.distributions.python.ops.gamma import *
|
||||
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from scipy import special
|
||||
|
||||
@ -28,8 +29,11 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -479,5 +483,96 @@ class GenNewSeedTest(test.TestCase):
|
||||
self.assertTrue(distribution_util.gen_new_seed(None, "salt") is None)
|
||||
|
||||
|
||||
# TODO(jvdillon): Merge this test back into:
|
||||
# tensorflow/python/kernel_tests/softplus_op_test.py
|
||||
# once TF core is accepting new ops.
|
||||
class SoftplusTest(test.TestCase):
|
||||
|
||||
def _npSoftplus(self, np_features):
|
||||
np_features = np.asarray(np_features)
|
||||
zero = np.asarray(0).astype(np_features.dtype)
|
||||
return np.logaddexp(zero, np_features)
|
||||
|
||||
def _testSoftplus(self, np_features, use_gpu=False):
|
||||
np_features = np.asarray(np_features)
|
||||
np_softplus = self._npSoftplus(np_features)
|
||||
with self.test_session(use_gpu=use_gpu) as sess:
|
||||
softplus = nn_ops.softplus(np_features)
|
||||
softplus_inverse = distribution_util.softplus_inverse(softplus)
|
||||
[tf_softplus, tf_softplus_inverse] = sess.run([
|
||||
softplus, softplus_inverse])
|
||||
self.assertAllCloseAccordingToType(np_softplus, tf_softplus)
|
||||
rtol = {"float16": 0.07, "float32": 0.003, "float64": 0.002}.get(
|
||||
str(np_features.dtype), 1e-6)
|
||||
# This will test that we correctly computed the inverse by verifying we
|
||||
# recovered the original input.
|
||||
self.assertAllCloseAccordingToType(
|
||||
np_features, tf_softplus_inverse,
|
||||
atol=0., rtol=rtol)
|
||||
self.assertAllEqual(np.ones_like(tf_softplus).astype(np.bool),
|
||||
tf_softplus > 0)
|
||||
|
||||
self.assertShapeEqual(np_softplus, softplus)
|
||||
self.assertShapeEqual(np_softplus, softplus_inverse)
|
||||
|
||||
self.assertAllEqual(np.ones_like(tf_softplus).astype(np.bool),
|
||||
np.isfinite(tf_softplus))
|
||||
self.assertAllEqual(np.ones_like(tf_softplus_inverse).astype(np.bool),
|
||||
np.isfinite(tf_softplus_inverse))
|
||||
|
||||
def testNumbers(self):
|
||||
for t in [np.float16, np.float32, np.float64]:
|
||||
lower = {np.float16: -15, np.float32: -50, np.float64: -50}.get(t, -100)
|
||||
upper = {np.float16: 50, np.float32: 50, np.float64: 50}.get(t, 100)
|
||||
self._testSoftplus(
|
||||
np.array(np.linspace(lower, upper, int(1e3)).astype(t)).reshape(
|
||||
[2, -1]),
|
||||
use_gpu=False)
|
||||
self._testSoftplus(
|
||||
np.array(np.linspace(lower, upper, int(1e3)).astype(t)).reshape(
|
||||
[2, -1]),
|
||||
use_gpu=True)
|
||||
log_eps = np.log(np.finfo(t).eps)
|
||||
one = t(1)
|
||||
ten = t(10)
|
||||
self._testSoftplus(
|
||||
[
|
||||
log_eps, log_eps - one, log_eps + one, log_eps - ten,
|
||||
log_eps + ten, -log_eps, -log_eps - one, -log_eps + one,
|
||||
-log_eps - ten, -log_eps + ten
|
||||
],
|
||||
use_gpu=False)
|
||||
self._testSoftplus(
|
||||
[
|
||||
log_eps, log_eps - one, log_eps + one, log_eps - ten,
|
||||
log_eps + ten - log_eps, -log_eps - one, -log_eps + one,
|
||||
-log_eps - ten, -log_eps + ten
|
||||
],
|
||||
use_gpu=True)
|
||||
|
||||
def testGradient(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(
|
||||
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
|
||||
shape=[2, 5],
|
||||
name="x")
|
||||
y = nn_ops.softplus(x, name="softplus")
|
||||
x_init = np.asarray(
|
||||
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
|
||||
dtype=np.float32,
|
||||
order="F")
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
x, [2, 5], y, [2, 5], x_init_value=x_init)
|
||||
print("softplus (float) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testInverseSoftplusGradientNeverNan(self):
|
||||
with self.test_session():
|
||||
# Note that this range contains both zero and inf.
|
||||
x = constant_op.constant((10.**np.arange(-8, 6)).astype(np.float16))
|
||||
y = distribution_util.softplus_inverse(x).eval()
|
||||
# Equivalent to `assertAllFalse` (if it existed).
|
||||
self.assertAllEqual(np.zeros_like(y).astype(np.bool), np.isnan(y))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -0,0 +1,283 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for MultivariateStudentsT Distribution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from scipy import linalg
|
||||
from scipy import special
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops.vector_student_t import _VectorStudentT
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class _FakeVectorStudentT(object):
|
||||
"""Fake scipy implementation for Multivariate Student's t-distribution.
|
||||
|
||||
Technically we don't need to test the `Vector Student's t-distribution` since
|
||||
its composed of only unit-tested parts. However this _FakeVectorStudentT
|
||||
serves as something like an end-to-end test of the
|
||||
`TransformedDistribution + Affine` API.
|
||||
|
||||
Other `Vector*` implementations need only test new code. That we don't need
|
||||
to test every Vector* distribution is good because there aren't SciPy
|
||||
analogues and reimplementing everything in NumPy sort of defeats the point of
|
||||
having the `TransformedDistribution + Affine` API.
|
||||
"""
|
||||
|
||||
def __init__(self, df, shift, scale_tril):
|
||||
self._df = np.asarray(df)
|
||||
self._shift = np.asarray(shift)
|
||||
self._scale_tril = np.asarray(scale_tril)
|
||||
|
||||
def log_prob(self, x):
|
||||
def _compute(df, shift, scale_tril, x):
|
||||
k = scale_tril.shape[-1]
|
||||
ildj = np.sum(np.log(np.abs(np.diag(scale_tril))), axis=-1)
|
||||
logz = ildj + k * (0.5 * np.log(df) +
|
||||
0.5 * np.log(np.pi) +
|
||||
special.gammaln(0.5 * df) -
|
||||
special.gammaln(0.5 * (df + 1.)))
|
||||
y = linalg.solve_triangular(scale_tril, np.matrix(x - shift).T,
|
||||
lower=True, overwrite_b=True)
|
||||
logs = -0.5 * (df + 1.) * np.sum(np.log1p(y**2. / df), axis=-2)
|
||||
return logs - logz
|
||||
if not self._df.shape:
|
||||
return _compute(self._df, self._shift, self._scale_tril, x)
|
||||
return np.concatenate([
|
||||
[_compute(self._df[i], self._shift[i], self._scale_tril[i], x[:, i, :])]
|
||||
for i in range(len(self._df))]).T
|
||||
|
||||
def prob(self, x):
|
||||
return np.exp(self.log_prob(x))
|
||||
|
||||
|
||||
class VectorStudentTTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._rng = np.random.RandomState(42)
|
||||
|
||||
def testProbStaticScalar(self):
|
||||
with self.test_session():
|
||||
# Scalar batch_shape.
|
||||
df = np.asarray(3., dtype=np.float32)
|
||||
# Scalar batch_shape.
|
||||
shift = np.asarray([1], dtype=np.float32)
|
||||
scale_diag = np.asarray([2.], dtype=np.float32)
|
||||
scale_tril = np.diag(scale_diag)
|
||||
|
||||
expected_mst = _FakeVectorStudentT(
|
||||
df=df, shift=shift, scale_tril=scale_tril)
|
||||
|
||||
actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
|
||||
validate_args=True)
|
||||
x = 2. * self._rng.rand(4, 1).astype(np.float32) - 1.
|
||||
|
||||
self.assertAllClose(expected_mst.log_prob(x),
|
||||
actual_mst.log_prob(x).eval(),
|
||||
rtol=0., atol=1e-5)
|
||||
self.assertAllClose(expected_mst.prob(x),
|
||||
actual_mst.prob(x).eval(),
|
||||
rtol=0., atol=1e-5)
|
||||
|
||||
def testProbStatic(self):
|
||||
# Non-scalar batch_shape.
|
||||
df = np.asarray([1., 2, 3], dtype=np.float32)
|
||||
# Non-scalar batch_shape.
|
||||
shift = np.asarray([[0., 0, 0],
|
||||
[1, 2, 3],
|
||||
[1, 0, 1]],
|
||||
dtype=np.float32)
|
||||
scale_diag = np.asarray([[1., 2, 3],
|
||||
[2, 3, 4],
|
||||
[4, 5, 6]],
|
||||
dtype=np.float32)
|
||||
scale_tril = np.concatenate([[np.diag(scale_diag[i])]
|
||||
for i in range(len(scale_diag))])
|
||||
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
|
||||
|
||||
expected_mst = _FakeVectorStudentT(
|
||||
df=df, shift=shift, scale_tril=scale_tril)
|
||||
|
||||
with self.test_session():
|
||||
actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
|
||||
validate_args=True)
|
||||
self.assertAllClose(expected_mst.log_prob(x),
|
||||
actual_mst.log_prob(x).eval(),
|
||||
rtol=0., atol=1e-5)
|
||||
self.assertAllClose(expected_mst.prob(x),
|
||||
actual_mst.prob(x).eval(),
|
||||
rtol=0., atol=1e-5)
|
||||
|
||||
def testProbDynamic(self):
|
||||
# Non-scalar batch_shape.
|
||||
df = np.asarray([1., 2, 3], dtype=np.float32)
|
||||
# Non-scalar batch_shape.
|
||||
shift = np.asarray([[0., 0, 0],
|
||||
[1, 2, 3],
|
||||
[1, 0, 1]],
|
||||
dtype=np.float32)
|
||||
scale_diag = np.asarray([[1., 2, 3],
|
||||
[2, 3, 4],
|
||||
[4, 5, 6]],
|
||||
dtype=np.float32)
|
||||
scale_tril = np.concatenate([[np.diag(scale_diag[i])]
|
||||
for i in range(len(scale_diag))])
|
||||
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
|
||||
|
||||
expected_mst = _FakeVectorStudentT(
|
||||
df=df, shift=shift, scale_tril=scale_tril)
|
||||
|
||||
with self.test_session():
|
||||
df_pl = array_ops.placeholder(dtypes.float32, name="df")
|
||||
shift_pl = array_ops.placeholder(dtypes.float32, name="shift")
|
||||
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
|
||||
feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag}
|
||||
actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
|
||||
validate_args=True)
|
||||
self.assertAllClose(expected_mst.log_prob(x),
|
||||
actual_mst.log_prob(x).eval(feed_dict=feed_dict),
|
||||
rtol=0., atol=1e-5)
|
||||
self.assertAllClose(expected_mst.prob(x),
|
||||
actual_mst.prob(x).eval(feed_dict=feed_dict),
|
||||
rtol=0., atol=1e-5)
|
||||
|
||||
def testProbScalarBaseDistributionNonScalarTransform(self):
|
||||
# Scalar batch_shape.
|
||||
df = np.asarray(2., dtype=np.float32)
|
||||
# Non-scalar batch_shape.
|
||||
shift = np.asarray([[0., 0, 0],
|
||||
[1, 2, 3],
|
||||
[1, 0, 1]],
|
||||
dtype=np.float32)
|
||||
scale_diag = np.asarray([[1., 2, 3],
|
||||
[2, 3, 4],
|
||||
[4, 5, 6]],
|
||||
dtype=np.float32)
|
||||
scale_tril = np.concatenate([[np.diag(scale_diag[i])]
|
||||
for i in range(len(scale_diag))])
|
||||
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
|
||||
|
||||
expected_mst = _FakeVectorStudentT(
|
||||
df=np.tile(df, len(scale_diag)),
|
||||
shift=shift,
|
||||
scale_tril=scale_tril)
|
||||
|
||||
with self.test_session():
|
||||
actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
|
||||
validate_args=True)
|
||||
self.assertAllClose(expected_mst.log_prob(x),
|
||||
actual_mst.log_prob(x).eval(),
|
||||
rtol=0., atol=1e-5)
|
||||
self.assertAllClose(expected_mst.prob(x),
|
||||
actual_mst.prob(x).eval(),
|
||||
rtol=0., atol=1e-5)
|
||||
|
||||
def testProbScalarBaseDistributionNonScalarTransformDynamic(self):
|
||||
# Scalar batch_shape.
|
||||
df = np.asarray(2., dtype=np.float32)
|
||||
# Non-scalar batch_shape.
|
||||
shift = np.asarray([[0., 0, 0],
|
||||
[1, 2, 3],
|
||||
[1, 0, 1]],
|
||||
dtype=np.float32)
|
||||
scale_diag = np.asarray([[1., 2, 3],
|
||||
[2, 3, 4],
|
||||
[4, 5, 6]],
|
||||
dtype=np.float32)
|
||||
scale_tril = np.concatenate([[np.diag(scale_diag[i])]
|
||||
for i in range(len(scale_diag))])
|
||||
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
|
||||
|
||||
expected_mst = _FakeVectorStudentT(
|
||||
df=np.tile(df, len(scale_diag)),
|
||||
shift=shift,
|
||||
scale_tril=scale_tril)
|
||||
|
||||
with self.test_session():
|
||||
df_pl = array_ops.placeholder(dtypes.float32, name="df")
|
||||
shift_pl = array_ops.placeholder(dtypes.float32, name="shift")
|
||||
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
|
||||
feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag}
|
||||
actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
|
||||
validate_args=True)
|
||||
self.assertAllClose(expected_mst.log_prob(x),
|
||||
actual_mst.log_prob(x).eval(feed_dict=feed_dict),
|
||||
rtol=0., atol=1e-5)
|
||||
self.assertAllClose(expected_mst.prob(x),
|
||||
actual_mst.prob(x).eval(feed_dict=feed_dict),
|
||||
rtol=0., atol=1e-5)
|
||||
|
||||
def testProbNonScalarBaseDistributionScalarTransform(self):
|
||||
# Non-scalar batch_shape.
|
||||
df = np.asarray([1., 2., 3.], dtype=np.float32)
|
||||
# Scalar batch_shape.
|
||||
shift = np.asarray([1, 2, 3], dtype=np.float32)
|
||||
scale_diag = np.asarray([2, 3, 4], dtype=np.float32)
|
||||
scale_tril = np.diag(scale_diag)
|
||||
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
|
||||
|
||||
expected_mst = _FakeVectorStudentT(
|
||||
df=df,
|
||||
shift=np.tile(shift[None, :], [len(df), 1]),
|
||||
scale_tril=np.tile(scale_tril[None, :, :], [len(df), 1, 1]))
|
||||
|
||||
with self.test_session():
|
||||
actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
|
||||
validate_args=True)
|
||||
self.assertAllClose(expected_mst.log_prob(x),
|
||||
actual_mst.log_prob(x).eval(),
|
||||
rtol=0., atol=1e-5)
|
||||
self.assertAllClose(expected_mst.prob(x),
|
||||
actual_mst.prob(x).eval(),
|
||||
rtol=0., atol=1e-5)
|
||||
|
||||
def testProbNonScalarBaseDistributionScalarTransformDynamic(self):
|
||||
# Non-scalar batch_shape.
|
||||
df = np.asarray([1., 2., 3.], dtype=np.float32)
|
||||
# Scalar batch_shape.
|
||||
shift = np.asarray([1, 2, 3], dtype=np.float32)
|
||||
scale_diag = np.asarray([2, 3, 4], dtype=np.float32)
|
||||
scale_tril = np.diag(scale_diag)
|
||||
|
||||
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
|
||||
|
||||
expected_mst = _FakeVectorStudentT(
|
||||
df=df,
|
||||
shift=np.tile(shift[None, :], [len(df), 1]),
|
||||
scale_tril=np.tile(scale_tril[None, :, :], [len(df), 1, 1]))
|
||||
|
||||
with self.test_session():
|
||||
df_pl = array_ops.placeholder(dtypes.float32, name="df")
|
||||
shift_pl = array_ops.placeholder(dtypes.float32, name="shift")
|
||||
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
|
||||
feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag}
|
||||
actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
|
||||
validate_args=True)
|
||||
self.assertAllClose(expected_mst.log_prob(x),
|
||||
actual_mst.log_prob(x).eval(feed_dict=feed_dict),
|
||||
rtol=0., atol=1e-5)
|
||||
self.assertAllClose(expected_mst.prob(x),
|
||||
actual_mst.prob(x).eval(feed_dict=feed_dict),
|
||||
rtol=0., atol=1e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -53,10 +53,12 @@ import contextlib
|
||||
import itertools
|
||||
import math
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib import framework as contrib_framework
|
||||
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
|
||||
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
|
||||
from tensorflow.contrib.distributions.python.ops import operator_pd_identity
|
||||
@ -75,6 +77,7 @@ from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Affine",
|
||||
"AffineLinearOperator",
|
||||
@ -92,16 +95,6 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
# TODO(jvdillon): deprecate this function once tf.expm1 exists.
|
||||
def _expm1(x):
|
||||
"""Approximate exp{y}-1~=y for small |y|, and exp{y}-1 elsewhere."""
|
||||
# Recall, eps is smallest positive number such that 1 + eps != 1.
|
||||
eps = np.finfo(x.dtype.base_dtype.as_numpy_dtype).eps
|
||||
# Note we are careful to never send an NaN through ANY branch of where.
|
||||
return array_ops.where(math_ops.less(math_ops.abs(x), eps),
|
||||
x, math_ops.exp(x) - 1.)
|
||||
|
||||
|
||||
def _as_tensor(x, name):
|
||||
"""Convenience to convert to `Tensor` or leave as `None`."""
|
||||
return None if x is None else ops.convert_to_tensor(x, name=name)
|
||||
@ -1271,7 +1264,7 @@ class PowerTransform(Bijector):
|
||||
return x, ildj
|
||||
# TODO(jvdillon): If large y accuracy is an issue, consider using
|
||||
# (y**self.power - 1.) / self.power when y >> 1.
|
||||
x = _expm1(math_ops.log(y) * self.power) / self.power
|
||||
x = math_ops.expm1(math_ops.log(y) * self.power) / self.power
|
||||
ildj = (self.power - 1.) * math_ops.reduce_sum(
|
||||
math_ops.log(y),
|
||||
reduction_indices=event_dims)
|
||||
@ -1590,6 +1583,7 @@ class Affine(Bijector):
|
||||
`scale_diag` has shape [N1, N2, ... k, k], which represents a k x k
|
||||
lower triangular matrix.
|
||||
When `None` no `scale_tril` term is added to `scale`.
|
||||
The upper triangular elements above the diagonal are ignored.
|
||||
scale_perturb_factor: Numeric `Tensor` representing factor matrix with
|
||||
last two dimensions of shape `(k, r)`.
|
||||
When `None`, no rank-r update is added to `scale`.
|
||||
@ -2086,31 +2080,21 @@ class Softplus(Bijector):
|
||||
return nn_ops.softplus(x)
|
||||
|
||||
def _inverse_and_inverse_log_det_jacobian(self, y):
|
||||
# The most stable inverse of softplus is not the most obvious one.
|
||||
# y = softplus(x) = Log[1 + exp{x}], (which means y > 0).
|
||||
# ==> exp{y} = 1 + exp{x} (1)
|
||||
# ==> x = Log[exp{y} - 1] (2)
|
||||
# = Log[(exp{y} - 1) / exp{y}] + Log[exp{y}]
|
||||
# = Log[(1 - exp{-y}) / 1] + Log[exp{y}]
|
||||
# = Log[1 - exp{-y}] + y (3)
|
||||
# (2) is the "obvious" inverse, but (3) is more stable than (2) for large y.
|
||||
# For small y (e.g. y = 1e-10), (3) will become -inf since 1 - exp{-y} will
|
||||
# be zero. To fix this, we use 1 - exp{-y} approx y for small y > 0.
|
||||
#
|
||||
# Stable inverse log det jacobian.
|
||||
if self.shaper is None:
|
||||
raise ValueError("Jacobian cannot be computed with unknown event_ndims")
|
||||
_, _, event_dims = self.shaper.get_dims(y)
|
||||
# Could also do:
|
||||
# ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y),
|
||||
# reduction_indices=event_dims)
|
||||
# but the following is more numerically stable. Ie,
|
||||
# Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1]
|
||||
# ==> dX/dY = exp{Y} / (exp{Y} - 1)
|
||||
# = 1 / (1 - exp{-Y}),
|
||||
# which is the most stable for large Y > 0. For small Y, we use
|
||||
# 1 - exp{-Y} approx Y.
|
||||
if self.shaper is None:
|
||||
raise ValueError("Jacobian cannot be computed with unknown event_ndims")
|
||||
_, _, event_dims = self.shaper.get_dims(y)
|
||||
log_one_minus_exp_neg = math_ops.log(-_expm1(-y))
|
||||
x = y + log_one_minus_exp_neg
|
||||
ildj = -math_ops.reduce_sum(
|
||||
log_one_minus_exp_neg, reduction_indices=event_dims)
|
||||
return x, ildj
|
||||
ildj = -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)),
|
||||
reduction_indices=event_dims)
|
||||
return distribution_util.softplus_inverse(y), ildj
|
||||
|
||||
def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument
|
||||
if self.shaper is None:
|
||||
|
@ -564,6 +564,63 @@ def fill_lower_triangular(x, validate_args=False, name="fill_lower_triangular"):
|
||||
return y
|
||||
|
||||
|
||||
# TODO(jvdillon): Merge this test back into:
|
||||
# tensorflow/python/ops/softplus_op_test.py
|
||||
# once TF core is accepting new ops.
|
||||
def softplus_inverse(x, name=None):
|
||||
"""Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).
|
||||
|
||||
Mathematically this op is equivalent to:
|
||||
|
||||
```none
|
||||
softplus_inverse = log(exp(x) - 1.)
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor`. Non-negative (not enforced), floating-point.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
`Tensor`. Has the same type/shape as input `x`.
|
||||
"""
|
||||
with ops.name_scope(name, "softplus_inverse", values=[x]):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
# We begin by deriving a more numerically stable softplus_inverse:
|
||||
# x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
|
||||
# ==> exp{x} = 1 + exp{y} (1)
|
||||
# ==> y = Log[exp{x} - 1] (2)
|
||||
# = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
|
||||
# = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
|
||||
# = Log[1 - exp{-x}] + x (3)
|
||||
# (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
|
||||
# For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
|
||||
# be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
|
||||
#
|
||||
# In addition to the numerically stable derivation above, we clamp
|
||||
# small/large values to be congruent with the logic in:
|
||||
# tensorflow/core/kernels/softplus_op.h
|
||||
#
|
||||
# Finally, we set the input to one whenever the input is too large or too
|
||||
# small. This ensures that no unchosen codepath is +/- inf. This is
|
||||
# necessary to ensure the gradient doesn't get NaNs. Recall that the
|
||||
# gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
|
||||
# thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
|
||||
# to overwrite `x` with ones only when we will never actually use this
|
||||
# value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
|
||||
threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
|
||||
is_too_small = math_ops.less(x, np.exp(threshold))
|
||||
is_too_large = math_ops.greater(x, -threshold)
|
||||
too_small_value = math_ops.log(x)
|
||||
too_large_value = x
|
||||
# This `where` will ultimately be a NOP because we won't select this
|
||||
# codepath whenever we used the surrogate `ones_like`.
|
||||
x = array_ops.where(math_ops.logical_or(is_too_small, is_too_large),
|
||||
array_ops.ones_like(x), x)
|
||||
y = x + math_ops.log(-math_ops.expm1(-x)) # == log(expm1(x))
|
||||
return array_ops.where(is_too_small, too_small_value,
|
||||
array_ops.where(is_too_large, too_large_value, y))
|
||||
|
||||
|
||||
class AppendDocstring(object):
|
||||
"""Helper class to promote private subclass docstring to public counterpart.
|
||||
|
||||
|
295
tensorflow/contrib/distributions/python/ops/vector_student_t.py
Normal file
295
tensorflow/contrib/distributions/python/ops/vector_student_t.py
Normal file
@ -0,0 +1,295 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Vector Student's t distribution classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import bijector as bijectors
|
||||
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||
from tensorflow.contrib.distributions.python.ops import student_t
|
||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
|
||||
# TODO(jvdillon): Add unittests for this once we know where will put this code
|
||||
# and how it will generally be used. In the interim this code is tested via the
|
||||
# _VectorStudentT tests.
|
||||
def _infer_shapes(scale_oppd, shift):
|
||||
"""Helper which returns batch_shape, event_shape from `Affine` properties.
|
||||
|
||||
The `Affine` `Bijector` (roughly) computes `Y = scale @ X.T + shift`. This
|
||||
function infers the `batch_shape` and `event_shape` from the `scale` and
|
||||
`shift` terms.
|
||||
|
||||
Args:
|
||||
scale_oppd: Instance of OperatorPDBase subclass representing the `Affine`
|
||||
`Bijector` scale matrix.
|
||||
shift: `Tensor` representing the `shift` vector.
|
||||
|
||||
Returns:
|
||||
batch_shape: 1D, integer `Tensor` representing the shape of batch
|
||||
dimensions.
|
||||
event_shape: 1D, integer `Tensor` representing the shape of event
|
||||
dimensions.
|
||||
|
||||
Raises:
|
||||
ValueError: if we are not able to infer batch/event shapes from the args.
|
||||
"""
|
||||
# Collect known static shape.
|
||||
def _has_static_ndims(x):
|
||||
return x is not None and x.get_shape().ndims is not None
|
||||
if _has_static_ndims(scale_oppd) and _has_static_ndims(shift):
|
||||
batch_shape = scale_oppd.get_batch_shape().merge_with(
|
||||
shift.get_shape()[:-1])
|
||||
event_shape = scale_oppd.get_shape()[-1:].merge_with(
|
||||
shift.get_shape()[-1:])
|
||||
elif _has_static_ndims(scale_oppd):
|
||||
batch_shape = scale_oppd.get_batch_shape()
|
||||
event_shape = scale_oppd.get_shape()[-1:]
|
||||
elif _has_static_ndims(shift):
|
||||
batch_shape = shift.get_shape()[:-1]
|
||||
event_shape = shift.get_shape()[-1:]
|
||||
else:
|
||||
batch_shape = tensor_shape.TensorShape(None)
|
||||
event_shape = tensor_shape.TensorShape(None)
|
||||
|
||||
# Convert TensorShape to Tensors and see if we're done.
|
||||
if batch_shape.is_fully_defined():
|
||||
batch_shape = constant_op.constant(batch_shape.as_list(),
|
||||
dtype=dtypes.int32)
|
||||
else:
|
||||
batch_shape = None
|
||||
if event_shape.is_fully_defined():
|
||||
event_shape = constant_op.constant(event_shape.as_list(),
|
||||
dtype=dtypes.int32)
|
||||
else:
|
||||
event_shape = None
|
||||
if batch_shape is not None and event_shape is not None:
|
||||
return batch_shape, event_shape
|
||||
|
||||
# Collect known dynamic shape.
|
||||
if scale_oppd is not None:
|
||||
shape = scale_oppd.shape()
|
||||
elif shift is not None:
|
||||
shape = array_ops.shape(shift)
|
||||
else:
|
||||
raise ValueError("unable to infer batch_shape, event_shape")
|
||||
|
||||
# Fill in what we don't know.
|
||||
if batch_shape is None:
|
||||
batch_shape = array_ops.identity(shape[:-1], name="batch_shape")
|
||||
if event_shape is None:
|
||||
event_shape = array_ops.identity(shape[-1:], name="event_shape")
|
||||
|
||||
return batch_shape, event_shape
|
||||
|
||||
|
||||
class _VectorStudentT(transformed_distribution.TransformedDistribution):
|
||||
"""A vector version of Student's t-distribution on `R^k`.
|
||||
|
||||
#### Mathematical details
|
||||
|
||||
Write `S` for the scale matrix (in R^{k x k}) and `mu` for the mean (in R^k).
|
||||
The PDF of this distribution is:
|
||||
|
||||
```none
|
||||
f(x) = (1 + y y.T / df)**(-0.5 (df + 1)) / Z
|
||||
where,
|
||||
y(x) = inv(S) (x - mu)
|
||||
Z = abs(det(S)) ( sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1)) )**k
|
||||
```
|
||||
|
||||
Notice that the matrix `S` has semantics more similar to standard deviation
|
||||
than covariance.
|
||||
|
||||
This distribution is an Affine transformation of iid
|
||||
[Student's t-distributions](
|
||||
https://en.wikipedia.org/wiki/Student%27s_t-distribution)
|
||||
and should not be confused with the [Multivate Student's t-distribution](
|
||||
https://en.wikipedia.org/wiki/Multivariate_t-distribution). The
|
||||
traditional Multivariate Student's t-distribution is type of
|
||||
[elliptical distribution](
|
||||
https://en.wikipedia.org/wiki/Elliptical_distribution); it has PDF:
|
||||
|
||||
```none
|
||||
f(x) = (1 + y y.T / df)**(-0.5 (df + k)) / Z
|
||||
where,
|
||||
y(x) = inv(S) (x - mu)
|
||||
Z = abs(det(S)) sqrt(df pi)**k Gamma(0.5 df) / Gamma(0.5 (df + k))
|
||||
```
|
||||
|
||||
Notice that the Multivariate Student's t-distribution uses `k` where the
|
||||
Vector Student's t-distribution has a `1`. Conversely the Vector version has a
|
||||
broader application of the power-`k` in the normalization.
|
||||
|
||||
#### Examples
|
||||
|
||||
A single instance of a "Vector Student's t-distribution" is defined by a mean
|
||||
vector of of length `k` and a scale matrix of shape `k x k`.
|
||||
|
||||
Extra leading dimensions, if provided, allow for batches.
|
||||
|
||||
```python
|
||||
ds = tf.contrib.distributions
|
||||
|
||||
# Initialize a single 3-variate vector Student's t-distribution.
|
||||
mu = [1., 2, 3]
|
||||
chol = [[1., 0, 0.],
|
||||
[1, 3, 0],
|
||||
[1, 2, 3]]
|
||||
vt = ds.VectorStudentT(df=2, shift=mu, scale_tril=chol)
|
||||
|
||||
# Evaluate this on an observation in R^3, returning a scalar.
|
||||
vt.prob([-1., 0, 1])
|
||||
|
||||
# Initialize a batch of two 3-variate vector Student's t-distributions.
|
||||
mu = [[1., 2, 3],
|
||||
[11, 22, 33]]
|
||||
chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal.
|
||||
vt = ds.VectorStudentT(shift=mu, scale_tril=chol)
|
||||
|
||||
# Evaluate this on a two observations, each in R^3, returning a length two
|
||||
# tensor.
|
||||
x = [[-1, 0, 1],
|
||||
[-11, 0, 11]]
|
||||
vt.prob(x)
|
||||
```
|
||||
|
||||
For more examples of how to construct the `scale` matrix, see the
|
||||
`bijector.Affine` docstring.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
df,
|
||||
shift=None,
|
||||
scale_identity_multiplier=None,
|
||||
scale_diag=None,
|
||||
scale_tril=None,
|
||||
scale_perturb_factor=None,
|
||||
scale_perturb_diag=None,
|
||||
validate_args=False,
|
||||
allow_nan_stats=True,
|
||||
name="VectorStudentT"):
|
||||
"""Instantiates the vector Student's t-distributions on `R^k`.
|
||||
|
||||
The `batch_shape` is the broadcast between `df.batch_shape` and
|
||||
`Affine.batch_shape` where `Affine` is constructed from `shift` and
|
||||
`scale_*` arguments.
|
||||
|
||||
The `event_shape` is the event shape of `Affine.event_shape`.
|
||||
|
||||
Args:
|
||||
df: Numeric `Tensor`. The degrees of freedom of the distribution(s).
|
||||
`df` must contain only positive values.
|
||||
Must be scalar if `shift`, `scale_*` imply non-scalar batch_shape or
|
||||
must have the same `batch_shape` implied by `shift`, `scale_*`.
|
||||
shift: Numeric `Tensor`. If this is set to `None`, no `shift` is applied.
|
||||
scale_identity_multiplier: floating point rank 0 `Tensor` representing a
|
||||
scaling done to the identity matrix.
|
||||
When `scale_identity_multiplier = scale_diag=scale_tril = None` then
|
||||
`scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
|
||||
to `scale`.
|
||||
scale_diag: Numeric `Tensor` representing the diagonal matrix.
|
||||
`scale_diag` has shape [N1, N2, ... k], which represents a k x k
|
||||
diagonal matrix.
|
||||
When `None` no diagonal term is added to `scale`.
|
||||
scale_tril: Numeric `Tensor` representing the diagonal matrix.
|
||||
`scale_diag` has shape [N1, N2, ... k, k], which represents a k x k
|
||||
lower triangular matrix.
|
||||
When `None` no `scale_tril` term is added to `scale`.
|
||||
The upper triangular elements above the diagonal are ignored.
|
||||
scale_perturb_factor: Numeric `Tensor` representing factor matrix with
|
||||
last two dimensions of shape `(k, r)`.
|
||||
When `None`, no rank-r update is added to `scale`.
|
||||
scale_perturb_diag: Numeric `Tensor` representing the diagonal matrix.
|
||||
`scale_perturb_diag` has shape [N1, N2, ... r], which represents an
|
||||
r x r Diagonal matrix.
|
||||
When `None` low rank updates will take the form `scale_perturb_factor *
|
||||
scale_perturb_factor.T`.
|
||||
validate_args: `Boolean`, default `False`. Whether to validate input
|
||||
with asserts. If `validate_args` is `False`, and the inputs are
|
||||
invalid, correct behavior is not guaranteed.
|
||||
allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
|
||||
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||
batch member If `True`, batch members with valid parameters leading to
|
||||
undefined statistics will return NaN for this statistic.
|
||||
name: The name to give Ops created by the initializer.
|
||||
"""
|
||||
parameters = locals()
|
||||
parameters.pop("self")
|
||||
graph_parents = [df, shift, scale_identity_multiplier, scale_diag,
|
||||
scale_tril, scale_perturb_factor, scale_perturb_diag]
|
||||
with ops.name_scope(name) as ns:
|
||||
with ops.name_scope("init", values=graph_parents):
|
||||
# The shape of the _VectorStudentT distribution is governed by the
|
||||
# relationship between df.batch_shape and affine.batch_shape. In
|
||||
# pseudocode the basic procedure is:
|
||||
# if df.batch_shape is scalar:
|
||||
# if affine.batch_shape is not scalar:
|
||||
# # broadcast self._distribution.sample so
|
||||
# # it has affine.batch_shape.
|
||||
# self.batch_shape = affine.batch_shape
|
||||
# else:
|
||||
# if affine.batch_shape is scalar:
|
||||
# # let affine broadcasting do its thing.
|
||||
# self.batch_shape = df.batch_shape
|
||||
# All of the above magic is actually handled by TransformedDistribution.
|
||||
# Here we really only need to collect the affine.batch_shape and decide
|
||||
# what we're going to pass in to TransformedDistribution's
|
||||
# (override) batch_shape arg.
|
||||
self._distribution = student_t.StudentT(df=df, mu=0., sigma=1.)
|
||||
self._affine = bijectors.Affine(
|
||||
shift=shift,
|
||||
scale_identity_multiplier=scale_identity_multiplier,
|
||||
scale_diag=scale_diag,
|
||||
scale_tril=scale_tril,
|
||||
scale_perturb_factor=scale_perturb_factor,
|
||||
scale_perturb_diag=scale_perturb_diag,
|
||||
validate_args=validate_args)
|
||||
self._batch_shape, self._override_event_shape = _infer_shapes(
|
||||
self.scale, self.shift)
|
||||
self._override_batch_shape = distribution_util.pick_vector(
|
||||
self._distribution.is_scalar_batch(),
|
||||
self._batch_shape,
|
||||
constant_op.constant([], dtype=dtypes.int32))
|
||||
super(_VectorStudentT, self).__init__(
|
||||
distribution=self._distribution,
|
||||
bijector=self._affine,
|
||||
batch_shape=self._override_batch_shape,
|
||||
event_shape=self._override_event_shape,
|
||||
validate_args=validate_args,
|
||||
name=ns)
|
||||
|
||||
@property
|
||||
def df(self):
|
||||
"""Degrees of freedom in these Student's t distribution(s)."""
|
||||
return self._distribution.df
|
||||
|
||||
@property
|
||||
def shift(self):
|
||||
"""Locations of these Student's t distribution(s)."""
|
||||
return self._affine.shift
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
"""Dense (batch) covariance matrix, if available."""
|
||||
return self._affine.scale
|
@ -229,7 +229,7 @@ class TransformerTest(test.TestCase):
|
||||
self.assertEqual(len(output), 1)
|
||||
self.assertIn(keys_sparse, output)
|
||||
with self.test_session():
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
|
||||
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
|
||||
self.assertAllEqual(output[keys_sparse].indices.eval(),
|
||||
@ -247,7 +247,7 @@ class TransformerTest(test.TestCase):
|
||||
output = feature_column_ops._Transformer(features).transform(keys_sparse)
|
||||
|
||||
with self.test_session():
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
# While the input is a dense Tensor, the output should be a SparseTensor.
|
||||
self.assertIsInstance(output, sparse_tensor.SparseTensor)
|
||||
self.assertEqual(output.dtype, dtypes.int64)
|
||||
@ -316,7 +316,7 @@ class TransformerTest(test.TestCase):
|
||||
self.assertIn(weighted_ids, output)
|
||||
|
||||
with self.test_session():
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
|
||||
ids_tensor.dense_shape.eval())
|
||||
self.assertAllEqual(output[weighted_ids][0].indices.eval(),
|
||||
@ -346,7 +346,7 @@ class TransformerTest(test.TestCase):
|
||||
self.assertEqual(len(output), 1)
|
||||
self.assertIn(vocab_sparse, output)
|
||||
with self.test_session():
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
||||
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
|
||||
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
||||
@ -368,7 +368,7 @@ class TransformerTest(test.TestCase):
|
||||
self.assertEqual(len(output), 1)
|
||||
self.assertIn(vocab_sparse, output)
|
||||
with self.test_session():
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
||||
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
|
||||
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
||||
@ -392,7 +392,7 @@ class TransformerTest(test.TestCase):
|
||||
self.assertEqual(len(output), 1)
|
||||
self.assertIn(vocab_sparse, output)
|
||||
with self.test_session():
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
||||
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
|
||||
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
||||
@ -414,7 +414,7 @@ class TransformerTest(test.TestCase):
|
||||
self.assertEqual(len(output), 1)
|
||||
self.assertIn(vocab_sparse, output)
|
||||
with self.test_session():
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
|
||||
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
|
||||
self.assertAllEqual(output[vocab_sparse].indices.eval(),
|
||||
@ -584,7 +584,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual(output.eval().shape, [3, 3 + 4 + 10])
|
||||
|
||||
def testRealValuedColumn(self):
|
||||
@ -681,7 +681,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
[one_hot_column])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
|
||||
output.eval())
|
||||
|
||||
@ -699,7 +699,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
|
||||
output.eval())
|
||||
|
||||
@ -717,7 +717,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
||||
output.eval())
|
||||
|
||||
@ -751,7 +751,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
[one_hot_sparse])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual([3, 10], output.eval().shape)
|
||||
|
||||
def testEmbeddingColumnSucceedsForDNN(self):
|
||||
@ -857,7 +857,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
[embeded_sparse])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual(output.eval().shape, [2, 10])
|
||||
|
||||
def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self):
|
||||
@ -908,7 +908,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Error creating input layer for column: ids_weighted_by_weights"):
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
feature_column_ops.input_from_feature_columns(features, [weighted_ids])
|
||||
|
||||
def testCrossedColumnFailsForDNN(self):
|
||||
@ -1015,7 +1015,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
[embeded_sparse])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
# score: (sum of weights)
|
||||
self.assertAllEqual(output.eval(), [[10.], [50.], [0.]])
|
||||
|
||||
@ -1208,7 +1208,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
||||
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
model_input = sess.run(model_input_tensor)
|
||||
|
||||
expected_input_shape = np.array([4, 3, 4])
|
||||
@ -1242,7 +1242,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
||||
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
model_input = sess.run(model_input_tensor)
|
||||
|
||||
expected_input_shape = np.array([4, 3, hash_buckets])
|
||||
@ -1272,7 +1272,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
||||
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
model_input = sess.run(model_input_tensor)
|
||||
|
||||
self.assertAllEqual(expected_input_shape, model_input.shape)
|
||||
@ -1302,7 +1302,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
||||
embedding_weights)
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
|
||||
|
||||
expected_input_shape = [4, 3, embedding_dimension]
|
||||
@ -1369,7 +1369,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
||||
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
model_input = sess.run(model_input_tensor)
|
||||
|
||||
expected_input_shape = [
|
||||
@ -1437,7 +1437,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [weighted_ids], num_outputs=5)
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||
|
||||
def testWeightedSparseColumnWithDenseInputTensor(self):
|
||||
@ -1453,7 +1453,7 @@ class WeightedSumTest(test.TestCase):
|
||||
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||
|
||||
def testCrossedColumn(self):
|
||||
@ -1507,7 +1507,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [movies], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.initialize_all_variables().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[movies][0]
|
||||
self.assertEqual(weights.get_shape(), (3, 1))
|
||||
@ -1582,7 +1582,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [age, language], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
self.assertAllClose(output.eval(), [[0.], [0.]])
|
||||
|
||||
@ -1622,7 +1622,7 @@ class WeightedSumTest(test.TestCase):
|
||||
self.assertEqual(len(variables), 1)
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
self.assertAllClose(output.eval(), [[0.], [0.]])
|
||||
|
||||
@ -1686,7 +1686,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [weighted_language], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
self.assertAllClose(output.eval(), [[0.], [0.]])
|
||||
|
||||
@ -1714,7 +1714,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [language], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
# score: 0.1 + language_weight['hindi'] + language_weight['english']
|
||||
sess.run(bias.assign([0.1]))
|
||||
@ -1737,7 +1737,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [movies], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[movies][0]
|
||||
self.assertEqual(weights.get_shape(), (15, 1))
|
||||
@ -1771,7 +1771,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [country_language], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[country_language][0]
|
||||
sess.run(weights.assign(weights + 0.4))
|
||||
@ -1795,7 +1795,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [language_language], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[language_language][0]
|
||||
sess.run(weights.assign(weights + 0.4))
|
||||
@ -1828,7 +1828,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [country_language], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[country_language][0]
|
||||
sess.run(weights.assign(weights + 0.4))
|
||||
@ -1869,7 +1869,7 @@ class WeightedSumTest(test.TestCase):
|
||||
scope=scope))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
self.assertEqual(2, len(column_to_variable[country]))
|
||||
self.assertEqual(3, len(column_to_variable[language]))
|
||||
@ -1906,7 +1906,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [country, age, incomes], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
incomes_weights = column_to_variable[incomes][0]
|
||||
sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]]))
|
||||
@ -1943,7 +1943,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [country, age, height, incomes], num_outputs=5))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
height_weights = column_to_variable[height][0]
|
||||
sess.run(
|
||||
@ -1973,7 +1973,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [bucket], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3],
|
||||
[0.4]]))
|
||||
@ -2001,7 +2001,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [bucket, country], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
# dimension = 2, bucket_size = 4, num_classes = 1
|
||||
sess.run(column_to_variable[bucket][0].assign(
|
||||
@ -2030,7 +2030,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [bucket, country], num_outputs=5))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
# dimension = 2, bucket_size = 4, num_classes = 5
|
||||
sess.run(column_to_variable[bucket][0].assign(
|
||||
@ -2066,7 +2066,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [country_price], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[country_price][0]
|
||||
sess.run(weights.assign(weights + 0.4))
|
||||
@ -2105,7 +2105,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [country_language_price], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[country_language_price][0]
|
||||
sess.run(weights.assign(weights + 0.4))
|
||||
@ -2129,7 +2129,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [product], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
product_weights = column_to_variable[product][0]
|
||||
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
||||
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
||||
@ -2144,7 +2144,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [product], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
product_weights = column_to_variable[product][0]
|
||||
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
||||
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
||||
@ -2159,7 +2159,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [product], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
product_weights = column_to_variable[product][0]
|
||||
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
||||
self.assertAllClose(output.eval(), [[0.6], [0.7]])
|
||||
@ -2180,7 +2180,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [product], num_outputs=1))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
product_weights = column_to_variable[product][0]
|
||||
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
|
||||
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
|
||||
@ -2192,7 +2192,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [feature_column.real_valued_column("age")], num_outputs=3)
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
sess.run(bias.assign([0.1, 0.2, 0.3]))
|
||||
self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3],
|
||||
[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]])
|
||||
@ -2206,7 +2206,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [column], num_outputs=3))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
weights = column_to_variable[column][0]
|
||||
self.assertEqual(weights.get_shape(), (1, 3))
|
||||
sess.run(weights.assign([[0.01, 0.03, 0.05]]))
|
||||
@ -2230,7 +2230,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [column], num_outputs=3))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
weights = column_to_variable[column][0]
|
||||
self.assertEqual(weights.get_shape(), (5, 3))
|
||||
sess.run(
|
||||
@ -2256,7 +2256,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [column], num_outputs=3))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[column][0]
|
||||
self.assertEqual(weights.get_shape(), (5, 3))
|
||||
@ -2296,7 +2296,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [column], num_outputs=3))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[column][0]
|
||||
self.assertEqual(weights.get_shape(), (5, 3))
|
||||
@ -2325,7 +2325,7 @@ class WeightedSumTest(test.TestCase):
|
||||
features, [column], num_outputs=3))
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
weights = column_to_variable[column][0]
|
||||
self.assertEqual(weights.get_shape(), (5, 3))
|
||||
@ -2390,7 +2390,7 @@ class ParseExampleTest(test.TestCase):
|
||||
self.assertIn(bucket, output)
|
||||
self.assertIn(wire_cast, output)
|
||||
with self.test_session():
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
|
||||
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
|
||||
self.assertAllEqual(output[wire_cast].values.eval(), [2, 0])
|
||||
|
@ -160,7 +160,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
|
||||
self.context_feature_columns)
|
||||
with self.test_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(data_flow_ops.initialize_all_tables())
|
||||
sess.run(data_flow_ops.tables_initializer())
|
||||
sequence_input_val = sess.run(sequence_input)
|
||||
expected_shape = np.array([
|
||||
3, # expected batch size
|
||||
@ -181,7 +181,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
|
||||
# Obtain values of activations and final state.
|
||||
with session.Session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(data_flow_ops.initialize_all_tables())
|
||||
sess.run(data_flow_ops.tables_initializer())
|
||||
activations, final_state = sess.run([activations_t, final_state_t])
|
||||
|
||||
expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
|
||||
|
@ -1283,7 +1283,7 @@ class Estimator(BaseEstimator):
|
||||
|
||||
with tf_session.Session('') as session:
|
||||
variables.initialize_local_variables()
|
||||
data_flow_ops.initialize_all_tables()
|
||||
data_flow_ops.tables_initializer()
|
||||
saver_for_restore = saver.Saver(
|
||||
variables.global_variables(),
|
||||
sharded=True)
|
||||
@ -1291,7 +1291,7 @@ class Estimator(BaseEstimator):
|
||||
|
||||
init_op = control_flow_ops.group(
|
||||
variables.local_variables_initializer(),
|
||||
data_flow_ops.initialize_all_tables())
|
||||
data_flow_ops.tables_initializer())
|
||||
|
||||
# Perform the export
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
@ -22,8 +22,6 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib.factorization.python.ops import clustering_ops
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps
|
||||
from tensorflow.python.framework import ops
|
||||
@ -42,7 +40,7 @@ KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT
|
||||
|
||||
|
||||
# TODO(agarwal,ands): support sharded input.
|
||||
class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
|
||||
class KMeansClustering(estimator.Estimator):
|
||||
"""An Estimator fo rK-Means clustering."""
|
||||
SCORES = 'scores'
|
||||
CLUSTER_IDX = 'cluster_idx'
|
||||
@ -58,6 +56,7 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
|
||||
random_seed=0,
|
||||
use_mini_batch=True,
|
||||
kmeans_plus_plus_num_retries=2,
|
||||
relative_tolerance=None,
|
||||
config=None):
|
||||
"""Creates a model for running KMeans training and inference.
|
||||
|
||||
@ -76,6 +75,9 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
|
||||
additional points to draw from the current distribution before selecting
|
||||
the best. If a negative value is specified, a heuristic is used to
|
||||
sample O(log(num_to_sample)) additional points.
|
||||
relative_tolerance: A relative tolerance of change in the loss between
|
||||
iterations. Stops learning if the loss changes less than this amount.
|
||||
Note that this may not work correctly if use_mini_batch=True.
|
||||
config: See Estimator
|
||||
"""
|
||||
self._num_clusters = num_clusters
|
||||
@ -84,7 +86,8 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
|
||||
self._random_seed = random_seed
|
||||
self._use_mini_batch = use_mini_batch
|
||||
self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
|
||||
self._estimator = estimator.Estimator(
|
||||
self._relative_tolerance = relative_tolerance
|
||||
super(KMeansClustering, self).__init__(
|
||||
model_fn=self._get_model_function(), model_dir=model_dir)
|
||||
|
||||
class LossRelativeChangeHook(session_run_hook.SessionRunHook):
|
||||
@ -119,76 +122,13 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
|
||||
run_context.request_stop()
|
||||
self._prev_loss = loss
|
||||
|
||||
@property
|
||||
def model_dir(self):
|
||||
"""See Evaluable."""
|
||||
return self._estimator.model_dir
|
||||
|
||||
def fit(self,
|
||||
input_fn=None,
|
||||
steps=None,
|
||||
monitors=None,
|
||||
max_steps=None,
|
||||
relative_tolerance=None):
|
||||
"""Trains a k-means clustering on x.
|
||||
|
||||
Note: See Estimator for logic for continuous training and graph
|
||||
construction across multiple calls to fit.
|
||||
|
||||
Args:
|
||||
input_fn: see Trainable.fit.
|
||||
steps: see Trainable.fit.
|
||||
monitors: see Trainable.fit.
|
||||
max_steps: see Trainable.fit.
|
||||
relative_tolerance: A relative tolerance of change in the loss between
|
||||
iterations. Stops learning if the loss changes less than this amount.
|
||||
Note that this may not work correctly if use_mini_batch=True.
|
||||
|
||||
Returns:
|
||||
Returns self.
|
||||
"""
|
||||
if relative_tolerance is not None:
|
||||
if monitors is None:
|
||||
monitors = []
|
||||
monitors.append(self.LossRelativeChangeHook(relative_tolerance))
|
||||
# Make sure that we will eventually terminate.
|
||||
assert ((monitors is not None and len(monitors)) or (steps is not None) or
|
||||
(max_steps is not None))
|
||||
self._estimator.fit(input_fn=input_fn,
|
||||
steps=steps,
|
||||
max_steps=max_steps,
|
||||
monitors=monitors)
|
||||
return self
|
||||
|
||||
def evaluate(self,
|
||||
input_fn=None,
|
||||
feed_fn=None,
|
||||
steps=None,
|
||||
metrics=None,
|
||||
name=None,
|
||||
checkpoint_path=None,
|
||||
hooks=None):
|
||||
"""See Evaluable.evaluate."""
|
||||
return self._estimator.evaluate(
|
||||
input_fn=input_fn,
|
||||
feed_fn=feed_fn,
|
||||
steps=steps,
|
||||
metrics=metrics,
|
||||
name=name,
|
||||
checkpoint_path=checkpoint_path,
|
||||
hooks=hooks)
|
||||
|
||||
def predict(self, input_fn=None, outputs=None, as_iterable=False):
|
||||
"""See BaseEstimator.predict."""
|
||||
|
||||
outputs = outputs or [KMeansClustering.CLUSTER_IDX]
|
||||
assert isinstance(outputs, list)
|
||||
results = self._estimator.predict(
|
||||
input_fn=input_fn, outputs=outputs, as_iterable=as_iterable)
|
||||
if len(outputs) == 1 and not as_iterable:
|
||||
return results[outputs[0]]
|
||||
else:
|
||||
return results
|
||||
def predict_cluster_idx(self, input_fn=None):
|
||||
"""Yields predicted cluster indices."""
|
||||
key = KMeansClustering.CLUSTER_IDX
|
||||
results = super(KMeansClustering, self).predict(
|
||||
input_fn=input_fn, outputs=[key])
|
||||
for result in results:
|
||||
yield result[key]
|
||||
|
||||
def score(self, input_fn=None, steps=None):
|
||||
"""Predict total sum of distances to nearest clusters.
|
||||
@ -223,14 +163,19 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
|
||||
Array with same number of rows as x, and num_clusters columns, containing
|
||||
distances to the cluster centers.
|
||||
"""
|
||||
return self.predict(
|
||||
key = KMeansClustering.ALL_SCORES
|
||||
results = super(KMeansClustering, self).predict(
|
||||
input_fn=input_fn,
|
||||
outputs=[KMeansClustering.ALL_SCORES],
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if not as_iterable:
|
||||
return results[key]
|
||||
else:
|
||||
return results
|
||||
|
||||
def clusters(self):
|
||||
"""Returns cluster centers."""
|
||||
return self._estimator.get_variable_value(self.CLUSTERS)
|
||||
return super(KMeansClustering, self).get_variable_value(self.CLUSTERS)
|
||||
|
||||
def _parse_tensor_or_dict(self, features):
|
||||
if isinstance(features, dict):
|
||||
@ -264,11 +209,16 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
|
||||
KMeansClustering.CLUSTER_IDX: model_predictions[0],
|
||||
}
|
||||
eval_metric_ops = {KMeansClustering.SCORES: loss,}
|
||||
if self._relative_tolerance is not None:
|
||||
training_hooks = [self.LossRelativeChangeHook(self._relative_tolerance)]
|
||||
else:
|
||||
training_hooks = None
|
||||
return ModelFnOps(
|
||||
mode=mode,
|
||||
predictions=predictions,
|
||||
eval_metric_ops=eval_metric_ops,
|
||||
loss=loss,
|
||||
train_op=training_op)
|
||||
train_op=training_op,
|
||||
training_hooks=training_hooks)
|
||||
|
||||
return _model_fn
|
||||
|
@ -40,6 +40,7 @@ from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import benchmark
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import input as input_lib
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
@ -68,7 +69,7 @@ def make_random_points(centers, num_points, max_offset=20):
|
||||
|
||||
class KMeansTestBase(test.TestCase):
|
||||
|
||||
def input_fn(self, batch_size=None, points=None):
|
||||
def input_fn(self, batch_size=None, points=None, num_epochs=None):
|
||||
"""Returns an input_fn that randomly selects batches from given points."""
|
||||
batch_size = batch_size or self.batch_size
|
||||
points = points if points is not None else self.points
|
||||
@ -77,14 +78,15 @@ class KMeansTestBase(test.TestCase):
|
||||
def _fn():
|
||||
x = constant_op.constant(points)
|
||||
if batch_size == num_points:
|
||||
return x, None
|
||||
return input_lib.limit_epochs(x, num_epochs=num_epochs), None
|
||||
indices = random_ops.random_uniform(
|
||||
constant_op.constant([batch_size]),
|
||||
minval=0,
|
||||
maxval=num_points - 1,
|
||||
dtype=dtypes.int32,
|
||||
seed=10)
|
||||
return array_ops.gather(x, indices), None
|
||||
batch = array_ops.gather(x, indices)
|
||||
return (input_lib.limit_epochs(batch, num_epochs=num_epochs), None)
|
||||
|
||||
return _fn
|
||||
|
||||
@ -113,21 +115,23 @@ class KMeansTest(KMeansTestBase):
|
||||
self.num_points)
|
||||
self.true_score = np.add.reduce(self.scores)
|
||||
|
||||
self.kmeans = kmeans_lib.KMeansClustering(
|
||||
def _kmeans(self, relative_tolerance=None):
|
||||
return kmeans_lib.KMeansClustering(
|
||||
self.num_centers,
|
||||
initial_clusters=factorization.RANDOM_INIT,
|
||||
use_mini_batch=self.use_mini_batch,
|
||||
config=self.config(14),
|
||||
random_seed=10)
|
||||
random_seed=10,
|
||||
relative_tolerance=relative_tolerance)
|
||||
|
||||
def test_clusters(self):
|
||||
kmeans = self.kmeans
|
||||
kmeans = self._kmeans()
|
||||
kmeans.fit(input_fn=self.input_fn(), steps=1)
|
||||
clusters = kmeans.clusters()
|
||||
self.assertAllEqual(list(clusters.shape), [self.num_centers, self.num_dims])
|
||||
|
||||
def test_fit(self):
|
||||
kmeans = self.kmeans
|
||||
kmeans = self._kmeans()
|
||||
kmeans.fit(input_fn=self.input_fn(), steps=1)
|
||||
score1 = kmeans.score(
|
||||
input_fn=self.input_fn(batch_size=self.num_points), steps=1)
|
||||
@ -146,20 +150,20 @@ class KMeansTest(KMeansTestBase):
|
||||
initial_clusters=factorization.RANDOM_INIT,
|
||||
use_mini_batch=self.use_mini_batch,
|
||||
config=run_config.RunConfig(tf_random_seed=14),
|
||||
random_seed=12)
|
||||
random_seed=12,
|
||||
relative_tolerance=1e-4)
|
||||
|
||||
kmeans.fit(
|
||||
input_fn=self.input_fn(),
|
||||
# Force it to train forever until the monitor stops it.
|
||||
steps=None,
|
||||
relative_tolerance=1e-4)
|
||||
# Force it to train until the relative tolerance monitor stops it.
|
||||
steps=None)
|
||||
score = kmeans.score(
|
||||
input_fn=self.input_fn(batch_size=self.num_points), steps=1)
|
||||
self.assertNear(self.true_score, score, self.true_score * 0.005)
|
||||
|
||||
def test_infer(self):
|
||||
kmeans = self.kmeans
|
||||
kmeans.fit(input_fn=self.input_fn(), relative_tolerance=1e-4)
|
||||
kmeans = self._kmeans(relative_tolerance=1e-4)
|
||||
kmeans.fit(input_fn=self.input_fn())
|
||||
clusters = kmeans.clusters()
|
||||
|
||||
# Make a small test set
|
||||
@ -167,8 +171,8 @@ class KMeansTest(KMeansTestBase):
|
||||
points, true_assignments, true_offsets = make_random_points(clusters,
|
||||
num_points)
|
||||
# Test predict
|
||||
assignments = kmeans.predict(input_fn=self.input_fn(
|
||||
batch_size=num_points, points=points))
|
||||
assignments = list(kmeans.predict_cluster_idx(input_fn=self.input_fn(
|
||||
batch_size=num_points, points=points, num_epochs=1)))
|
||||
self.assertAllEqual(assignments, true_assignments)
|
||||
|
||||
# Test score
|
||||
@ -260,7 +264,8 @@ class KMeansTestCosineDistance(KMeansTestBase):
|
||||
centers, axis=0), np.sort(
|
||||
self.true_centers, axis=0), atol=1e-2)
|
||||
|
||||
assignments = self.kmeans.predict(input_fn=self.input_fn())
|
||||
assignments = list(self.kmeans.predict_cluster_idx(
|
||||
input_fn=self.input_fn(num_epochs=1)))
|
||||
self.assertAllClose(
|
||||
centers[assignments],
|
||||
self.true_centers[self.true_assignments],
|
||||
@ -305,8 +310,11 @@ class KMeansTestCosineDistance(KMeansTestBase):
|
||||
self.assertAllClose(
|
||||
sorted(centers.tolist()), sorted(true_centers.tolist()), atol=1e-2)
|
||||
|
||||
assignments = kmeans.predict(
|
||||
input_fn=lambda: (constant_op.constant(points), None))
|
||||
def _input_fn():
|
||||
return (
|
||||
input_lib.limit_epochs(constant_op.constant(points), num_epochs=1),
|
||||
None)
|
||||
assignments = list(kmeans.predict_cluster_idx(input_fn=_input_fn))
|
||||
self.assertAllClose(
|
||||
centers[assignments], true_centers[true_assignments], atol=1e-2)
|
||||
|
||||
|
@ -25,3 +25,4 @@ class PredictionKey(object):
|
||||
LOGISTIC = "logistic"
|
||||
SCORES = "scores"
|
||||
TOP_K = "top_k"
|
||||
GENERIC = "output"
|
||||
|
@ -634,7 +634,7 @@ def _get_local_init_op():
|
||||
ops.GraphKeys.LOCAL_INIT_OP)
|
||||
if local_init_op is None:
|
||||
op_list = [variables.local_variables_initializer(),
|
||||
data_flow_ops.initialize_all_tables()]
|
||||
data_flow_ops.tables_initializer()]
|
||||
if op_list:
|
||||
local_init_op = control_flow_ops.group(*op_list)
|
||||
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
|
||||
@ -881,7 +881,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
|
||||
else:
|
||||
session.run(variables.global_variables_initializer())
|
||||
session.run(variables.local_variables_initializer())
|
||||
session.run(data_flow_ops.initialize_all_tables())
|
||||
session.run(data_flow_ops.tables_initializer())
|
||||
coord = coordinator.Coordinator()
|
||||
threads = None
|
||||
try:
|
||||
|
@ -156,12 +156,12 @@ def read_keyed_batch_examples(file_pattern,
|
||||
file_pattern,
|
||||
batch_size,
|
||||
reader,
|
||||
randomize_input,
|
||||
num_epochs,
|
||||
queue_capacity,
|
||||
num_threads,
|
||||
read_batch_size,
|
||||
parse_fn,
|
||||
randomize_input=randomize_input,
|
||||
num_epochs=num_epochs,
|
||||
queue_capacity=queue_capacity,
|
||||
num_threads=num_threads,
|
||||
read_batch_size=read_batch_size,
|
||||
parse_fn=parse_fn,
|
||||
setup_shared_queue=False,
|
||||
name=name)
|
||||
|
||||
@ -225,12 +225,12 @@ def _read_keyed_batch_examples_shared_queue(file_pattern,
|
||||
file_pattern,
|
||||
batch_size,
|
||||
reader,
|
||||
randomize_input,
|
||||
num_epochs,
|
||||
queue_capacity,
|
||||
num_threads,
|
||||
read_batch_size,
|
||||
parse_fn,
|
||||
randomize_input=randomize_input,
|
||||
num_epochs=num_epochs,
|
||||
queue_capacity=queue_capacity,
|
||||
num_threads=num_threads,
|
||||
read_batch_size=read_batch_size,
|
||||
parse_fn=parse_fn,
|
||||
setup_shared_queue=True,
|
||||
name=name)
|
||||
|
||||
@ -265,7 +265,7 @@ def _get_file_names(file_pattern, randomize_input):
|
||||
|
||||
|
||||
def _get_examples(file_name_queue, reader, num_threads, read_batch_size,
|
||||
parse_fn):
|
||||
filter_fn, parse_fn):
|
||||
with ops.name_scope('read'):
|
||||
example_list = []
|
||||
for _ in range(num_threads):
|
||||
@ -274,6 +274,10 @@ def _get_examples(file_name_queue, reader, num_threads, read_batch_size,
|
||||
read_batch_size)
|
||||
else:
|
||||
keys, examples_proto = reader().read(file_name_queue)
|
||||
if filter_fn:
|
||||
mask = filter_fn(keys, examples_proto)
|
||||
keys = array_ops.boolean_mask(keys, mask)
|
||||
examples_proto = array_ops.boolean_mask(examples_proto, mask)
|
||||
if parse_fn:
|
||||
parsed_examples = parse_fn(examples_proto)
|
||||
# Map keys into example map because batch_join doesn't support
|
||||
@ -296,6 +300,7 @@ def _read_keyed_batch_examples_helper(file_pattern,
|
||||
queue_capacity=10000,
|
||||
num_threads=1,
|
||||
read_batch_size=1,
|
||||
filter_fn=None,
|
||||
parse_fn=None,
|
||||
setup_shared_queue=False,
|
||||
name=None):
|
||||
@ -316,6 +321,9 @@ def _read_keyed_batch_examples_helper(file_pattern,
|
||||
num_threads: The number of threads enqueuing examples.
|
||||
read_batch_size: An int or scalar `Tensor` specifying the number of
|
||||
records to read at once
|
||||
filter_fn: Filtering function, takes both keys as well `Example` Tensors
|
||||
and returns a boolean mask of the same shape as the input Tensors to
|
||||
be applied for filtering. If `None`, no filtering is done.
|
||||
parse_fn: Parsing function, takes `Example` Tensor returns parsed
|
||||
representation. If `None`, no parsing is done.
|
||||
setup_shared_queue: Whether to set up a shared queue for file names.
|
||||
@ -366,7 +374,7 @@ def _read_keyed_batch_examples_helper(file_pattern,
|
||||
name=file_name_queue_scope)
|
||||
|
||||
example_list = _get_examples(file_name_queue, reader, num_threads,
|
||||
read_batch_size, parse_fn)
|
||||
read_batch_size, filter_fn, parse_fn)
|
||||
|
||||
enqueue_many = read_batch_size > 1
|
||||
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
@ -681,6 +682,67 @@ class GraphIOTest(test.TestCase):
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def test_keyed_features_filter(self):
|
||||
gfile.Glob = self._orig_glob
|
||||
lines = [
|
||||
'{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}',
|
||||
'{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
|
||||
'{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}',
|
||||
'{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
|
||||
'{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}',
|
||||
'{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}'
|
||||
]
|
||||
filename = self._create_temp_file("\n".join(lines))
|
||||
|
||||
batch_size = 2
|
||||
queue_capacity = 4
|
||||
name = "my_batch"
|
||||
features = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)}
|
||||
|
||||
def filter_fn(keys, examples_json):
|
||||
del keys
|
||||
serialized = parsing_ops.decode_json_example(examples_json)
|
||||
examples = parsing_ops.parse_example(serialized, features)
|
||||
return math_ops.less(examples["age"], 2)
|
||||
|
||||
with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
|
||||
keys, inputs = graph_io._read_keyed_batch_examples_helper(
|
||||
filename,
|
||||
batch_size,
|
||||
reader=io_ops.TextLineReader,
|
||||
randomize_input=False,
|
||||
num_epochs=1,
|
||||
read_batch_size=batch_size,
|
||||
queue_capacity=queue_capacity,
|
||||
filter_fn=filter_fn,
|
||||
name=name)
|
||||
self.assertAllEqual((None,), keys.get_shape().as_list())
|
||||
self.assertAllEqual((None,), inputs.get_shape().as_list())
|
||||
session.run(variables.local_variables_initializer())
|
||||
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(session, coord=coord)
|
||||
# First batch of two filtered examples.
|
||||
out_keys, out_vals = session.run((keys, inputs))
|
||||
self.assertAllEqual(
|
||||
[filename.encode("utf-8") + b":2", filename.encode("utf-8") + b":3"],
|
||||
out_keys)
|
||||
self.assertAllEqual([lines[1].encode("utf-8"), lines[2].encode("utf-8")],
|
||||
out_vals)
|
||||
|
||||
# Second batch will only have one filtered example as that's the only
|
||||
# remaining example that satisfies the filtering criterion.
|
||||
out_keys, out_vals = session.run((keys, inputs))
|
||||
self.assertAllEqual([filename.encode("utf-8") + b":4"], out_keys)
|
||||
self.assertAllEqual([lines[3].encode("utf-8")], out_vals)
|
||||
|
||||
# Exhausted input.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
session.run((keys, inputs))
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -88,20 +88,20 @@ def run(experiment_fn, output_dir, schedule=None):
|
||||
# Execute the schedule
|
||||
if not hasattr(experiment, schedule):
|
||||
logging.error('Schedule references non-existent task %s', schedule)
|
||||
valid_tasks = [x for x in experiment.__dict__
|
||||
if callable(getattr(experiment, x))]
|
||||
valid_tasks = [x for x in dir(experiment)
|
||||
if not x.startswith('_')
|
||||
and callable(getattr(experiment, x))]
|
||||
logging.error('Allowed values for this experiment are: %s', valid_tasks)
|
||||
raise ValueError('Schedule references non-existent task %s', schedule)
|
||||
raise ValueError('Schedule references non-existent task %s' % schedule)
|
||||
|
||||
task = getattr(experiment, schedule)
|
||||
if not callable(task):
|
||||
logging.error('Schedule references non-callable member %s', schedule)
|
||||
valid_tasks = [
|
||||
x for x in experiment.__dict__
|
||||
if callable(getattr(experiment, x)) and not x.startswith('_')
|
||||
]
|
||||
valid_tasks = [x for x in dir(experiment)
|
||||
if not x.startswith('_')
|
||||
and callable(getattr(experiment, x))]
|
||||
logging.error('Allowed values for this experiment are: %s', valid_tasks)
|
||||
raise TypeError('Schedule references non-callable member %s', schedule)
|
||||
raise TypeError('Schedule references non-callable member %s' % schedule)
|
||||
|
||||
return task()
|
||||
|
||||
|
@ -27,9 +27,11 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
|
||||
import ctypes
|
||||
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
||||
|
||||
from tensorflow.contrib.learn.python.learn import evaluable # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.contrib.learn.python.learn import experiment
|
||||
from tensorflow.contrib.learn.python.learn import learn_runner
|
||||
from tensorflow.contrib.learn.python.learn import run_config
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
@ -43,13 +45,22 @@ class TestExperiment(experiment.Experiment):
|
||||
self.default = default
|
||||
self.config = config
|
||||
|
||||
@property
|
||||
def estimator(self):
|
||||
|
||||
class Estimator(object):
|
||||
class Estimator(evaluable.Evaluable, trainable.Trainable):
|
||||
config = self.config
|
||||
|
||||
return Estimator()
|
||||
def model_dir(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
|
||||
monitors=None, max_steps=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
|
||||
batch_size=None, steps=None, metrics=None, name=None,
|
||||
checkpoint_path=None, hooks=None):
|
||||
raise NotImplementedError
|
||||
|
||||
super(TestExperiment, self).__init__(Estimator(), None, None)
|
||||
|
||||
def local_run(self):
|
||||
return "local_run"
|
||||
|
@ -66,13 +66,13 @@ def _export_graph(graph, saver, checkpoint_path, export_dir,
|
||||
with graph.as_default():
|
||||
with tf_session.Session('') as session:
|
||||
variables.local_variables_initializer()
|
||||
data_flow_ops.initialize_all_tables()
|
||||
data_flow_ops.tables_initializer()
|
||||
saver.restore(session, checkpoint_path)
|
||||
|
||||
export = exporter.Exporter(saver)
|
||||
export.init(init_op=control_flow_ops.group(
|
||||
variables.local_variables_initializer(),
|
||||
data_flow_ops.initialize_all_tables()),
|
||||
data_flow_ops.tables_initializer()),
|
||||
default_graph_signature=default_graph_signature,
|
||||
named_graph_signatures=named_graph_signatures,
|
||||
assets_collection=ops.get_collection(
|
||||
|
@ -29,11 +29,12 @@ from tensorflow.python.ops import parsing_ops
|
||||
# A return type allowing input_fns to return multiple values in a well-
|
||||
# defined way (analogous to ModelFnOps).
|
||||
# The expected return values are:
|
||||
# features: a dict of string to Tensor, giving the features to be passed to
|
||||
# the model.
|
||||
# labels: a dict of string to Tensor, giving labels (aka targets) for training.
|
||||
# default_inputs: a dict of string to Tensor, giving the input Tensors (if
|
||||
# any) that this input_fn expects to be fed.
|
||||
# features: a dict of string to `Tensor` or `SparseTensor`, giving the features
|
||||
# to be passed to the model.
|
||||
# labels: a dict of string to `Tensor` or `SparseTensor`, giving labels (aka
|
||||
# targets) for training.
|
||||
# default_inputs: a dict of string to `Tensor` or `SparseTensor`, giving the
|
||||
# input placeholders (if any) that this input_fn expects to be fed.
|
||||
InputFnOps = collections.namedtuple('InputFnOps',
|
||||
['features',
|
||||
'labels',
|
||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
@ -128,10 +127,15 @@ def get_input_alternatives(input_ops):
|
||||
if not features:
|
||||
raise ValueError('Features must be defined.')
|
||||
|
||||
# TODO(b/34253951): reinstate the "features" input_signature.
|
||||
# The "features" input_signature, as written, does not work with
|
||||
# SparseTensors. It is simply commented out as a stopgap, pending discussion
|
||||
# on the bug as to the correct solution.
|
||||
|
||||
# Add the "features" input_signature in any case.
|
||||
# Note defensive copy because model_fns alter the features dict.
|
||||
input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = (
|
||||
copy.copy(features))
|
||||
# input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = (
|
||||
# copy.copy(features))
|
||||
|
||||
return input_alternatives, features
|
||||
|
||||
@ -163,6 +167,8 @@ def get_output_alternatives(
|
||||
# interpret the model as single-headed of unknown type.
|
||||
default_problem_type = constants.ProblemType.UNSPECIFIED
|
||||
default_outputs = model_fn_ops.predictions
|
||||
if not isinstance(default_outputs, dict):
|
||||
default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs}
|
||||
actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY
|
||||
output_alternatives = {actual_default_output_alternative_key:
|
||||
(default_problem_type, default_outputs)}
|
||||
@ -182,9 +188,10 @@ def build_all_signature_defs(input_alternatives, output_alternatives,
|
||||
in output_alternatives.items()}
|
||||
|
||||
# Add the default SignatureDef
|
||||
default_inputs = input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY]
|
||||
default_inputs = input_alternatives.get(DEFAULT_INPUT_ALTERNATIVE_KEY)
|
||||
if not default_inputs:
|
||||
default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY]
|
||||
raise ValueError('A default input_alternative must be provided.')
|
||||
# default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY]
|
||||
# default outputs are guaranteed to exist above
|
||||
(default_problem_type, default_outputs) = (
|
||||
output_alternatives[actual_default_output_alternative_key])
|
||||
@ -252,7 +259,7 @@ def garbage_collect_exports(export_dir_base, exports_to_keep):
|
||||
def make_export_strategy(export_input_fn,
|
||||
default_output_alternative_key='default',
|
||||
assets_extra=None,
|
||||
export_as_text=False,
|
||||
as_text=False,
|
||||
exports_to_keep=None):
|
||||
"""Create an ExportStrategy for use with Experiment."""
|
||||
|
||||
@ -263,7 +270,7 @@ def make_export_strategy(export_input_fn,
|
||||
export_input_fn,
|
||||
default_output_alternative_key=default_output_alternative_key,
|
||||
assets_extra=assets_extra,
|
||||
export_as_text=export_as_text,
|
||||
as_text=as_text,
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
garbage_collect_exports(export_dir_base, exports_to_keep)
|
||||
|
@ -22,11 +22,14 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
# TODO: #6568 Remove this hack that makes dlopen() not crash.
|
||||
# pylint: disable=g-import-not-at-top
|
||||
|
||||
# TODO(jart): #6568 Remove this hack that makes dlopen() not crash.
|
||||
if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
|
||||
import ctypes
|
||||
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
||||
|
||||
from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
|
||||
@ -87,9 +90,9 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
self.assertEqual(input_alternatives[
|
||||
saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY],
|
||||
"bogus default input dict")
|
||||
self.assertEqual(input_alternatives[
|
||||
saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY],
|
||||
"bogus features dict")
|
||||
# self.assertEqual(input_alternatives[
|
||||
# saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY],
|
||||
# "bogus features dict")
|
||||
|
||||
def test_get_output_alternatives_explicit(self):
|
||||
provided_output_alternatives = {
|
||||
@ -122,6 +125,21 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
})
|
||||
}, output_alternatives)
|
||||
|
||||
def test_get_output_alternatives_implicit_single(self):
|
||||
prediction_tensor = constant_op.constant(["bogus"])
|
||||
model_fn_ops = model_fn.ModelFnOps(
|
||||
model_fn.ModeKeys.INFER,
|
||||
predictions=prediction_tensor,
|
||||
output_alternatives=None)
|
||||
|
||||
output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
|
||||
model_fn_ops)
|
||||
self.assertEqual({
|
||||
"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
|
||||
"output": prediction_tensor
|
||||
})
|
||||
}, output_alternatives)
|
||||
|
||||
def test_build_all_signature_defs(self):
|
||||
input_features = constant_op.constant(["10"])
|
||||
input_example = constant_op.constant(["11"])
|
||||
@ -168,20 +186,56 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
signature_def_utils.predict_signature_def({
|
||||
"input": input_example
|
||||
}, {"output": output_3}),
|
||||
"features_input_alternative:head-1":
|
||||
signature_def_utils.regression_signature_def(input_features,
|
||||
output_1),
|
||||
"features_input_alternative:head-2":
|
||||
signature_def_utils.classification_signature_def(input_features,
|
||||
output_2, None),
|
||||
"features_input_alternative:head-3":
|
||||
signature_def_utils.predict_signature_def({
|
||||
"input": input_features
|
||||
}, {"output": output_3}),
|
||||
# "features_input_alternative:head-1":
|
||||
# signature_def_utils.regression_signature_def(input_features,
|
||||
# output_1),
|
||||
# "features_input_alternative:head-2":
|
||||
# signature_def_utils.classification_signature_def(input_features,
|
||||
# output_2, None),
|
||||
# "features_input_alternative:head-3":
|
||||
# signature_def_utils.predict_signature_def({
|
||||
# "input": input_features
|
||||
# }, {"output": output_3}),
|
||||
}
|
||||
|
||||
self.assertDictEqual(expected_signature_defs, signature_defs)
|
||||
|
||||
def test_build_all_signature_defs_legacy_input_fn_not_supported(self):
|
||||
"""Tests that legacy input_fn returning (features, labels) raises error.
|
||||
|
||||
serving_input_fn must return InputFnOps including a default input
|
||||
alternative.
|
||||
"""
|
||||
input_features = constant_op.constant(["10"])
|
||||
input_ops = ({"features": input_features}, None)
|
||||
input_alternatives, _ = (
|
||||
saved_model_export_utils.get_input_alternatives(input_ops))
|
||||
output_1 = constant_op.constant(["1"])
|
||||
output_2 = constant_op.constant(["2"])
|
||||
output_3 = constant_op.constant(["3"])
|
||||
provided_output_alternatives = {
|
||||
"head-1": (constants.ProblemType.LINEAR_REGRESSION, {
|
||||
"some_output_1": output_1
|
||||
}),
|
||||
"head-2": (constants.ProblemType.CLASSIFICATION, {
|
||||
"some_output_2": output_2
|
||||
}),
|
||||
"head-3": (constants.ProblemType.UNSPECIFIED, {
|
||||
"some_output_3": output_3
|
||||
}),
|
||||
}
|
||||
model_fn_ops = model_fn.ModelFnOps(
|
||||
model_fn.ModeKeys.INFER,
|
||||
predictions={"some_output": constant_op.constant(["4"])},
|
||||
output_alternatives=provided_output_alternatives)
|
||||
output_alternatives, _ = (saved_model_export_utils.get_output_alternatives(
|
||||
model_fn_ops, "head-1"))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "A default input_alternative must be provided"):
|
||||
saved_model_export_utils.build_all_signature_defs(
|
||||
input_alternatives, output_alternatives, "head-1")
|
||||
|
||||
def test_get_timestamped_export_dir(self):
|
||||
export_dir_base = tempfile.mkdtemp() + "export/"
|
||||
export_dir_1 = saved_model_export_utils.get_timestamped_export_dir(
|
||||
@ -227,6 +281,19 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
self.assertTrue(gfile.Exists(export_dir_3))
|
||||
self.assertTrue(gfile.Exists(export_dir_4))
|
||||
|
||||
def test_make_export_strategy(self):
|
||||
"""Only tests that an ExportStrategy instance is created."""
|
||||
def _export_input_fn():
|
||||
return array_ops.constant([1]), None
|
||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||
export_input_fn=_export_input_fn,
|
||||
default_output_alternative_key="default",
|
||||
assets_extra={"from/path": "to/path"},
|
||||
as_text=False,
|
||||
exports_to_keep=5)
|
||||
self.assertTrue(
|
||||
isinstance(export_strategy, export_strategy_lib.ExportStrategy))
|
||||
|
||||
|
||||
def _create_test_export_dir(export_dir_base):
|
||||
export_dir = saved_model_export_utils.get_timestamped_export_dir(
|
||||
|
@ -31,6 +31,7 @@ Subclasses of `LinearOperator` provide a access to common methods on a
|
||||
|
||||
@@LinearOperatorDiag
|
||||
@@LinearOperatorIdentity
|
||||
@@LinearOperatorScaledIdentity
|
||||
@@LinearOperatorMatrix
|
||||
@@LinearOperatorTriL
|
||||
|
||||
|
@ -33,7 +33,7 @@ random_seed.set_random_seed(23)
|
||||
rng = np.random.RandomState(2016)
|
||||
|
||||
|
||||
class LinearOperatorIdentitytest(
|
||||
class LinearOperatorIdentityTest(
|
||||
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
|
||||
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
|
||||
|
||||
@ -251,5 +251,184 @@ class LinearOperatorIdentitytest(
|
||||
)
|
||||
|
||||
|
||||
class LinearOperatorScaledIdentityTest(
|
||||
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
|
||||
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
|
||||
|
||||
@property
|
||||
def _dtypes_to_test(self):
|
||||
# TODO(langmore) Test tf.float16 once tf.matrix_solve works in
|
||||
# 16bit.
|
||||
return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
|
||||
|
||||
def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
|
||||
shape = list(shape)
|
||||
assert shape[-1] == shape[-2]
|
||||
|
||||
batch_shape = shape[:-2]
|
||||
num_rows = shape[-1]
|
||||
|
||||
# Uniform values that are at least length 1 from the origin. Allows the
|
||||
# operator to be well conditioned.
|
||||
# Shape batch_shape
|
||||
multiplier = linear_operator_test_util.random_sign_uniform(
|
||||
shape=batch_shape, minval=1., maxval=2., dtype=dtype)
|
||||
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(num_rows, multiplier)
|
||||
|
||||
# Nothing to feed since LinearOperatorScaledIdentity takes no Tensor args.
|
||||
if use_placeholder:
|
||||
multiplier_ph = array_ops.placeholder(dtype=dtype)
|
||||
multiplier = multiplier.eval()
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows, multiplier_ph)
|
||||
feed_dict = {multiplier_ph: multiplier}
|
||||
else:
|
||||
feed_dict = None
|
||||
|
||||
multiplier_matrix = array_ops.expand_dims(
|
||||
array_ops.expand_dims(multiplier, -1), -1)
|
||||
mat = multiplier_matrix * linalg_ops.eye(
|
||||
num_rows, batch_shape=batch_shape, dtype=dtype)
|
||||
|
||||
return operator, mat, feed_dict
|
||||
|
||||
def test_assert_positive_definite_does_not_raise_when_positive(self):
|
||||
with self.test_session():
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=1.)
|
||||
operator.assert_positive_definite().run() # Should not fail
|
||||
|
||||
def test_assert_positive_definite_raises_when_negative(self):
|
||||
with self.test_session():
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=-1.)
|
||||
with self.assertRaisesOpError("not positive definite"):
|
||||
operator.assert_positive_definite().run()
|
||||
|
||||
def test_assert_non_singular_does_not_raise_when_non_singular(self):
|
||||
with self.test_session():
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=[1., 2., 3.])
|
||||
operator.assert_non_singular().run() # Should not fail
|
||||
|
||||
def test_assert_non_singular_raises_when_singular(self):
|
||||
with self.test_session():
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=[1., 2., 0.])
|
||||
with self.assertRaisesOpError("was singular"):
|
||||
operator.assert_non_singular().run()
|
||||
|
||||
def test_assert_self_adjoint_does_not_raise_when_self_adjoint(self):
|
||||
with self.test_session():
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=[1. + 0J])
|
||||
operator.assert_self_adjoint().run() # Should not fail
|
||||
|
||||
def test_assert_self_adjoint_raises_when_not_self_adjoint(self):
|
||||
with self.test_session():
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=[1. + 1J])
|
||||
with self.assertRaisesOpError("not self-adjoint"):
|
||||
operator.assert_self_adjoint().run()
|
||||
|
||||
def test_float16_apply(self):
|
||||
# float16 cannot be tested by base test class because tf.matrix_solve does
|
||||
# not work with float16.
|
||||
with self.test_session():
|
||||
multiplier = rng.rand(3).astype(np.float16)
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=multiplier)
|
||||
x = rng.randn(2, 3).astype(np.float16)
|
||||
y = operator.apply(x)
|
||||
self.assertAllClose(multiplier[..., None, None] * x, y.eval())
|
||||
|
||||
def test_non_scalar_num_rows_raises_static(self):
|
||||
# Many "test_...num_rows" tests are performed in LinearOperatorIdentity.
|
||||
with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"):
|
||||
linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=[2], multiplier=123.)
|
||||
|
||||
def test_wrong_matrix_dimensions_raises_static(self):
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=2.2)
|
||||
x = rng.randn(3, 3).astype(np.float32)
|
||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||
operator.apply(x)
|
||||
|
||||
def test_wrong_matrix_dimensions_raises_dynamic(self):
|
||||
num_rows = array_ops.placeholder(dtypes.int32)
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
|
||||
with self.test_session():
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows, multiplier=[1., 2], assert_proper_shapes=True)
|
||||
y = operator.apply(x)
|
||||
with self.assertRaisesOpError("Incompatible.*dimensions"):
|
||||
y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)})
|
||||
|
||||
def test_broadcast_apply_and_solve(self):
|
||||
# These cannot be done in the automated (base test class) tests since they
|
||||
# test shapes that tf.batch_matmul cannot handle.
|
||||
# In particular, tf.batch_matmul does not broadcast.
|
||||
with self.test_session() as sess:
|
||||
# Given this x and LinearOperatorScaledIdentity shape of (2, 1, 3, 3), the
|
||||
# broadcast shape of operator and 'x' is (2, 2, 3, 4)
|
||||
x = random_ops.random_normal(shape=(1, 2, 3, 4))
|
||||
|
||||
# operator is 2.2 * identity (with a batch shape).
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=3, multiplier=2.2 * array_ops.ones((2, 1)))
|
||||
|
||||
# Batch matrix of zeros with the broadcast shape of x and operator.
|
||||
zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)
|
||||
|
||||
# Test apply
|
||||
expected = x * 2.2 + zeros
|
||||
operator_apply = operator.apply(x)
|
||||
self.assertAllEqual(operator_apply.get_shape(), expected.get_shape())
|
||||
self.assertAllClose(*sess.run([operator_apply, expected]))
|
||||
|
||||
# Test solve
|
||||
expected = x / 2.2 + zeros
|
||||
operator_solve = operator.solve(x)
|
||||
self.assertAllEqual(operator_solve.get_shape(), expected.get_shape())
|
||||
self.assertAllClose(*sess.run([operator_solve, expected]))
|
||||
|
||||
def test_broadcast_apply_and_solve_scalar_scale_multiplier(self):
|
||||
# These cannot be done in the automated (base test class) tests since they
|
||||
# test shapes that tf.batch_matmul cannot handle.
|
||||
# In particular, tf.batch_matmul does not broadcast.
|
||||
with self.test_session() as sess:
|
||||
# Given this x and LinearOperatorScaledIdentity shape of (3, 3), the
|
||||
# broadcast shape of operator and 'x' is (1, 2, 3, 4), which is the same
|
||||
# shape as x.
|
||||
x = random_ops.random_normal(shape=(1, 2, 3, 4))
|
||||
|
||||
# operator is 2.2 * identity (with a batch shape).
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=3, multiplier=2.2)
|
||||
|
||||
# Test apply
|
||||
expected = x * 2.2
|
||||
operator_apply = operator.apply(x)
|
||||
self.assertAllEqual(operator_apply.get_shape(), expected.get_shape())
|
||||
self.assertAllClose(*sess.run([operator_apply, expected]))
|
||||
|
||||
# Test solve
|
||||
expected = x / 2.2
|
||||
operator_solve = operator.solve(x)
|
||||
self.assertAllEqual(operator_solve.get_shape(), expected.get_shape())
|
||||
self.assertAllClose(*sess.run([operator_solve, expected]))
|
||||
|
||||
def test_is_x_flags(self):
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=1.,
|
||||
is_positive_definite=False, is_non_singular=True)
|
||||
self.assertFalse(operator.is_positive_definite)
|
||||
self.assertTrue(operator.is_non_singular)
|
||||
self.assertTrue(operator.is_self_adjoint is None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib import linalg as linalg_lib
|
||||
from tensorflow.contrib.linalg.python.ops import linear_operator_util
|
||||
from tensorflow.python.framework import ops
|
||||
@ -26,6 +28,7 @@ from tensorflow.python.platform import test
|
||||
|
||||
linalg = linalg_lib
|
||||
random_seed.set_random_seed(23)
|
||||
rng = np.random.RandomState(0)
|
||||
|
||||
|
||||
class AssertZeroImagPartTest(test.TestCase):
|
||||
@ -88,5 +91,33 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
|
||||
z, message="ABC123").run()
|
||||
|
||||
|
||||
class DomainDimensionStubOperator(object):
|
||||
|
||||
def __init__(self, domain_dimension):
|
||||
self._domain_dimension = ops.convert_to_tensor(domain_dimension)
|
||||
|
||||
def domain_dimension_dynamic(self):
|
||||
return self._domain_dimension
|
||||
|
||||
|
||||
class AssertCompatibleMatrixDimensionsTest(test.TestCase):
|
||||
|
||||
def test_compatible_dimensions_do_not_raise(self):
|
||||
with self.test_session():
|
||||
x = ops.convert_to_tensor(rng.rand(2, 3, 4))
|
||||
operator = DomainDimensionStubOperator(3)
|
||||
# Should not raise
|
||||
linear_operator_util.assert_compatible_matrix_dimensions(
|
||||
operator, x).run()
|
||||
|
||||
def test_incompatible_dimensions_raise(self):
|
||||
with self.test_session():
|
||||
x = ops.convert_to_tensor(rng.rand(2, 4, 4))
|
||||
operator = DomainDimensionStubOperator(3)
|
||||
with self.assertRaisesOpError("Incompatible matrix dimensions"):
|
||||
linear_operator_util.assert_compatible_matrix_dimensions(
|
||||
operator, x).run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -29,12 +29,52 @@ from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
__all__ = ["LinearOperatorIdentity",]
|
||||
__all__ = [
|
||||
"LinearOperatorIdentity",
|
||||
"LinearOperatorScaledIdentity",
|
||||
]
|
||||
|
||||
|
||||
class LinearOperatorIdentity(linear_operator.LinearOperator):
|
||||
class BaseLinearOperatorIdentity(linear_operator.LinearOperator):
|
||||
|
||||
def _check_num_rows_possibly_add_asserts(self):
|
||||
"""Static check of init arg `num_rows`, possibly add asserts."""
|
||||
# Possibly add asserts.
|
||||
if self._assert_proper_shapes:
|
||||
self._num_rows = control_flow_ops.with_dependencies(
|
||||
[
|
||||
check_ops.assert_rank(
|
||||
self._num_rows,
|
||||
0,
|
||||
message="Argument num_rows must be a 0-D Tensor."),
|
||||
check_ops.assert_non_negative(
|
||||
self._num_rows,
|
||||
message="Argument num_rows must be non-negative."),
|
||||
],
|
||||
self._num_rows)
|
||||
|
||||
# Static checks.
|
||||
if not self._num_rows.dtype.is_integer:
|
||||
raise TypeError("Argument num_rows must be integer type. Found:"
|
||||
" %s" % self._num_rows)
|
||||
|
||||
num_rows_static = self._num_rows_static
|
||||
|
||||
if num_rows_static is None:
|
||||
return # Cannot do any other static checks.
|
||||
|
||||
if num_rows_static.ndim != 0:
|
||||
raise ValueError("Argument num_rows must be a 0-D Tensor. Found:"
|
||||
" %s" % num_rows_static)
|
||||
|
||||
if num_rows_static < 0:
|
||||
raise ValueError("Argument num_rows must be non-negative. Found:"
|
||||
" %s" % num_rows_static)
|
||||
|
||||
|
||||
class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
||||
"""`LinearOperator` acting like a [batch] square identity matrix.
|
||||
|
||||
This operator acts like a [batch] identity matrix `A` with shape
|
||||
@ -273,12 +313,10 @@ class LinearOperatorIdentity(linear_operator.LinearOperator):
|
||||
|
||||
def _apply(self, x, adjoint=False):
|
||||
# Note that adjoint has no effect since this matrix is self-adjoint.
|
||||
if x.dtype != self.dtype:
|
||||
raise TypeError(
|
||||
"Expected argument 'x' to have dtype %s. Found: %s"
|
||||
% (self.dtype, x))
|
||||
if self._assert_proper_shapes:
|
||||
x = self._assert_compatible_matrix_dimensions(x)
|
||||
aps = linear_operator_util.assert_compatible_matrix_dimensions(
|
||||
self, x)
|
||||
x = control_flow_ops.with_dependencies([aps], x)
|
||||
return self._possibly_broadcast_batch_shape(x)
|
||||
|
||||
def _determinant(self):
|
||||
@ -290,12 +328,6 @@ class LinearOperatorIdentity(linear_operator.LinearOperator):
|
||||
def _solve(self, rhs, adjoint=False):
|
||||
return self._apply(rhs)
|
||||
|
||||
def _to_dense(self):
|
||||
return linalg_ops.eye(
|
||||
num_rows=self.domain_dimension_dynamic(),
|
||||
batch_shape=self.batch_shape_dynamic(),
|
||||
dtype=self.dtype)
|
||||
|
||||
def add_to_tensor(self, mat, name="add_to_tensor"):
|
||||
"""Add matrix represented by this operator to `mat`. Equiv to `I + mat`.
|
||||
|
||||
@ -381,14 +413,238 @@ class LinearOperatorIdentity(linear_operator.LinearOperator):
|
||||
raise ValueError("Argument batch_shape must be non-negative. Found:"
|
||||
"%s" % self._batch_shape_static)
|
||||
|
||||
def _assert_compatible_matrix_dimensions(self, x):
|
||||
"""Check that an argument to solve/apply has proper domain dimension."""
|
||||
# Static checks are done in the base class. Only dynamic asserts here.
|
||||
assert_same_dd = check_ops.assert_equal(
|
||||
array_ops.shape(x)[-2],
|
||||
self.domain_dimension_dynamic(),
|
||||
message=(
|
||||
"Incompatible matrix dimensions. "
|
||||
"shape[-2] of argument to be the same as this operator"))
|
||||
|
||||
return control_flow_ops.with_dependencies([assert_same_dd], x)
|
||||
class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
|
||||
"""`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`.
|
||||
|
||||
This operator acts like a scaled [batch] identity matrix `A` with shape
|
||||
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
|
||||
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
|
||||
a scaled version of the `N x N` identity matrix.
|
||||
|
||||
`LinearOperatorIdentity` is initialized with `num_rows`, and a `multiplier`
|
||||
(a `Tensor`) of shape `[B1,...,Bb]`. `N` is set to `num_rows`, and the
|
||||
`multiplier` determines the scale for each batch member.
|
||||
|
||||
```python
|
||||
# Create a 2 x 2 scaled identity matrix.
|
||||
operator = LinearOperatorIdentity(num_rows=2, multiplier=3.)
|
||||
|
||||
operator.to_dense()
|
||||
==> [[3., 0.]
|
||||
[0., 3.]]
|
||||
|
||||
operator.shape
|
||||
==> [2, 2]
|
||||
|
||||
operator.log_determinant()
|
||||
==> 2 * Log[3]
|
||||
|
||||
x = ... Shape [2, 4] Tensor
|
||||
operator.apply(x)
|
||||
==> 3 * x
|
||||
|
||||
y = tf.random_normal(shape=[3, 2, 4])
|
||||
# Note that y.shape is compatible with operator.shape because operator.shape
|
||||
# is broadcast to [3, 2, 2].
|
||||
x = operator.solve(y)
|
||||
==> 3 * x
|
||||
|
||||
# Create a 2-batch of 2x2 identity matrices
|
||||
operator = LinearOperatorIdentity(num_rows=2, multiplier=5.)
|
||||
operator.to_dense()
|
||||
==> [[[5., 0.]
|
||||
[0., 5.]],
|
||||
[[5., 0.]
|
||||
[0., 5.]]]
|
||||
|
||||
x = ... Shape [2, 2, 3]
|
||||
operator.apply(x)
|
||||
==> 5 * x
|
||||
|
||||
# Here the operator and x have different batch_shape, and are broadcast.
|
||||
x = ... Shape [1, 2, 3]
|
||||
operator.apply(x)
|
||||
==> 5 * x
|
||||
```
|
||||
|
||||
### Shape compatibility
|
||||
|
||||
This operator acts on [batch] matrix with compatible shape.
|
||||
`x` is a batch matrix with compatible shape for `apply` and `solve` if
|
||||
|
||||
```
|
||||
operator.shape = [B1,...,Bb] + [N, N], with b >= 0
|
||||
x.shape = [C1,...,Cc] + [N, R],
|
||||
and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
|
||||
```
|
||||
|
||||
### Performance
|
||||
|
||||
* `operator.apply(x)` is `O(D1*...*Dd*N*R)`
|
||||
* `operator.solve(x)` is `O(D1*...*Dd*N*R)`
|
||||
* `operator.determinant()` is `O(D1*...*Dd)`
|
||||
|
||||
#### Matrix property hints
|
||||
|
||||
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
|
||||
for `X = non_singular, self_adjoint, positive_definite`.
|
||||
These have the following meaning
|
||||
* If `is_X == True`, callers should expect the operator to have the
|
||||
property `X`. This is a promise that should be fulfilled, but is *not* a
|
||||
runtime assert. For example, finite floating point precision may result
|
||||
in these promises being violated.
|
||||
* If `is_X == False`, callers should expect the operator to not have `X`.
|
||||
* If `is_X == None` (the default), callers should have no expectation either
|
||||
way.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_rows,
|
||||
multiplier,
|
||||
is_non_singular=None,
|
||||
is_self_adjoint=None,
|
||||
is_positive_definite=None,
|
||||
assert_proper_shapes=False,
|
||||
name="LinearOperatorScaledIdentity"):
|
||||
"""Initialize a `LinearOperatorScaledIdentity`.
|
||||
|
||||
The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
|
||||
determines the size of each identity matrix, and a `multiplier`,
|
||||
which defines `dtype`, batch shape, and scale of each matrix.
|
||||
|
||||
This operator is able to broadcast the leading (batch) dimensions.
|
||||
|
||||
Args:
|
||||
num_rows: Scalar non-negative integer `Tensor`. Number of rows in the
|
||||
corresponding identity matrix.
|
||||
multiplier: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar).
|
||||
is_non_singular: Expect that this operator is non-singular.
|
||||
is_self_adjoint: Expect that this operator is equal to its hermitian
|
||||
transpose.
|
||||
is_positive_definite: Expect that this operator is positive definite.
|
||||
assert_proper_shapes: Python `bool`. If `False`, only perform static
|
||||
checks that initialization and method arguments have proper shape.
|
||||
If `True`, and static checks are inconclusive, add asserts to the graph.
|
||||
name: A name for this `LinearOperator`
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_rows` is determined statically to be non-scalar, or
|
||||
negative.
|
||||
"""
|
||||
self._assert_proper_shapes = assert_proper_shapes
|
||||
|
||||
with ops.name_scope(name, values=[multiplier, num_rows]):
|
||||
self._multiplier = ops.convert_to_tensor(multiplier, name="multiplier")
|
||||
|
||||
super(LinearOperatorScaledIdentity, self).__init__(
|
||||
dtype=self._multiplier.dtype,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
name=name)
|
||||
|
||||
# Shape [B1,...Bb, 1, 1]
|
||||
self._multiplier_matrix = array_ops.expand_dims(
|
||||
array_ops.expand_dims(self.multiplier, -1), -1)
|
||||
self._multiplier_matrix_conj = math_ops.conj(
|
||||
self._multiplier_matrix)
|
||||
self._abs_multiplier = math_ops.abs(self.multiplier)
|
||||
|
||||
self._num_rows = linear_operator_util.shape_tensor(
|
||||
num_rows, name="num_rows")
|
||||
self._num_rows_static = tensor_util.constant_value(self._num_rows)
|
||||
self._check_num_rows_possibly_add_asserts()
|
||||
self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype)
|
||||
self._num_rows_cast_to_real_dtype = math_ops.cast(
|
||||
self._num_rows, self.dtype.real_dtype)
|
||||
|
||||
def _shape(self):
|
||||
matrix_shape = tensor_shape.TensorShape(
|
||||
(self._num_rows_static, self._num_rows_static))
|
||||
|
||||
batch_shape = self.multiplier.get_shape()
|
||||
return batch_shape.concatenate(matrix_shape)
|
||||
|
||||
def _shape_dynamic(self):
|
||||
matrix_shape = array_ops.stack(
|
||||
(self._num_rows, self._num_rows), axis=0)
|
||||
|
||||
batch_shape = array_ops.shape(self.multiplier)
|
||||
return array_ops.concat((batch_shape, matrix_shape), 0)
|
||||
|
||||
def _assert_non_singular(self):
|
||||
return check_ops.assert_positive(
|
||||
math_ops.abs(self.multiplier),
|
||||
message="LinearOperator was singular")
|
||||
|
||||
def _assert_positive_definite(self):
|
||||
return check_ops.assert_positive(
|
||||
math_ops.real(self.multiplier),
|
||||
message="LinearOperator was not positive definite.")
|
||||
|
||||
def _assert_self_adjoint(self):
|
||||
imag_multiplier = math_ops.imag(self.multiplier)
|
||||
return check_ops.assert_equal(
|
||||
array_ops.zeros_like(imag_multiplier),
|
||||
imag_multiplier,
|
||||
message="LinearOperator was not self-adjoint")
|
||||
|
||||
def _apply(self, x, adjoint=False):
|
||||
if adjoint:
|
||||
matrix = self._multiplier_matrix_conj
|
||||
else:
|
||||
matrix = self._multiplier_matrix
|
||||
if self._assert_proper_shapes:
|
||||
aps = linear_operator_util.assert_compatible_matrix_dimensions(
|
||||
self, x)
|
||||
x = control_flow_ops.with_dependencies([aps], x)
|
||||
return x * matrix
|
||||
|
||||
def _determinant(self):
|
||||
return self.multiplier ** self._num_rows_cast_to_dtype
|
||||
|
||||
def _log_abs_determinant(self):
|
||||
return self._num_rows_cast_to_real_dtype * math_ops.log(
|
||||
self._abs_multiplier)
|
||||
|
||||
def _solve(self, rhs, adjoint=False):
|
||||
if adjoint:
|
||||
matrix = self._multiplier_matrix_conj
|
||||
else:
|
||||
matrix = self._multiplier_matrix
|
||||
if self._assert_proper_shapes:
|
||||
aps = linear_operator_util.assert_compatible_matrix_dimensions(
|
||||
self, rhs)
|
||||
rhs = control_flow_ops.with_dependencies([aps], rhs)
|
||||
return rhs / matrix
|
||||
|
||||
def add_to_tensor(self, mat, name="add_to_tensor"):
|
||||
"""Add matrix represented by this operator to `mat`. Equiv to `I + mat`.
|
||||
|
||||
Args:
|
||||
mat: `Tensor` with same `dtype` and shape broadcastable to `self`.
|
||||
name: A name to give this `Op`.
|
||||
|
||||
Returns:
|
||||
A `Tensor` with broadcast shape and same `dtype` as `self`.
|
||||
"""
|
||||
with self._name_scope(name, values=[mat]):
|
||||
# Shape [B1,...,Bb, 1]
|
||||
multiplier_vector = array_ops.expand_dims(self.multiplier, -1)
|
||||
|
||||
# Shape [C1,...,Cc, M, M]
|
||||
mat = ops.convert_to_tensor(mat, name="mat")
|
||||
|
||||
# Shape [C1,...,Cc, M]
|
||||
mat_diag = array_ops.matrix_diag_part(mat)
|
||||
|
||||
# multiplier_vector broadcasts here.
|
||||
new_diag = multiplier_vector + mat_diag
|
||||
|
||||
return array_ops.matrix_set_diag(mat, new_diag)
|
||||
|
||||
@property
|
||||
def multiplier(self):
|
||||
"""The [batch] scalar `Tensor`, `c` in `cI`."""
|
||||
return self._multiplier
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -67,6 +68,32 @@ def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
|
||||
return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
|
||||
|
||||
|
||||
def assert_compatible_matrix_dimensions(operator, x):
|
||||
"""Assert that an argument to solve/apply has proper domain dimension.
|
||||
|
||||
If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then
|
||||
`operator.apply(x)` is defined only if `N = Q`. This `Op` returns an
|
||||
`Assert` that "fires" if this is not the case. Static checks are already
|
||||
done by the base class `LinearOperator`.
|
||||
|
||||
Args:
|
||||
operator: `LinearOperator`.
|
||||
x: `Tensor`.
|
||||
|
||||
Returns:
|
||||
`Assert` `Op`.
|
||||
"""
|
||||
# Static checks are done in the base class. Only dynamic asserts here.
|
||||
assert_same_dd = check_ops.assert_equal(
|
||||
array_ops.shape(x)[-2],
|
||||
operator.domain_dimension_dynamic(),
|
||||
message=(
|
||||
"Incompatible matrix dimensions. "
|
||||
"shape[-2] of argument to be the same as this operator"))
|
||||
|
||||
return assert_same_dd
|
||||
|
||||
|
||||
def shape_tensor(shape, name=None):
|
||||
"""Convert Tensor using default type, unless empty list or tuple."""
|
||||
# Works just like random_ops._ShapeTensor.
|
||||
|
@ -795,7 +795,7 @@ def string_to_index_table_from_file(vocabulary_file=None,
|
||||
The bucket ID range is `[vocabulary size, vocabulary size + num_oov_buckets]`.
|
||||
|
||||
The underlying table must be initialized by calling
|
||||
`tf.initialize_all_tables.run()` or `table.init.run()` once.
|
||||
`tf.tables_initializer.run()` or `table.init.run()` once.
|
||||
|
||||
Sample Usages:
|
||||
|
||||
@ -813,7 +813,7 @@ def string_to_index_table_from_file(vocabulary_file=None,
|
||||
vocabulary_file="test.txt", num_oov_buckets=1)
|
||||
ids = table.lookup(features)
|
||||
...
|
||||
tf.initialize_all_tables().run()
|
||||
tf.tables_initializer().run()
|
||||
|
||||
ids.eval() ==> [0, 1, 3, 2] # where 3 is the out-of-vocabulary bucket
|
||||
```
|
||||
@ -893,7 +893,7 @@ def string_to_index_table_from_tensor(mapping,
|
||||
The bucket ID range is `[mapping size, mapping size + num_oov_buckets]`.
|
||||
|
||||
The underlying table must be initialized by calling
|
||||
`tf.initialize_all_tables.run()` or `table.init.run()` once.
|
||||
`tf.tables_initializer.run()` or `table.init.run()` once.
|
||||
|
||||
Elements in `mapping` cannot have duplicates, otherwise when executing the
|
||||
table initializer op, it will throw a `FailedPreconditionError`.
|
||||
@ -907,7 +907,7 @@ def string_to_index_table_from_tensor(mapping,
|
||||
features = tf.constant(["emerson", "lake", "and", "palmer"])
|
||||
ids = table.lookup(features)
|
||||
...
|
||||
tf.initialize_all_tables().run()
|
||||
tf.tables_initializer().run()
|
||||
|
||||
ids.eval() ==> [0, 1, 4, 2]
|
||||
```
|
||||
@ -975,7 +975,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None):
|
||||
will throw a FailedPreconditionError.
|
||||
|
||||
The underlying table must be initialized by calling
|
||||
`tf.initialize_all_tables.run()` once.
|
||||
`tf.tables_initializer.run()` once.
|
||||
|
||||
For example:
|
||||
|
||||
@ -985,7 +985,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None):
|
||||
ids = tf.contrib.lookup.string_to_index(
|
||||
feats, mapping=mapping_strings, default_value=-1)
|
||||
...
|
||||
tf.initialize_all_tables().run()
|
||||
tf.tables_initializer().run()
|
||||
|
||||
ids.eval() ==> [0, 1, -1, 2]
|
||||
```
|
||||
@ -1022,7 +1022,7 @@ def index_to_string_table_from_file(vocabulary_file,
|
||||
(an out-of-vocabulary entry) is assigned the `default_value`
|
||||
|
||||
The underlying table must be initialized by calling
|
||||
`tf.initialize_all_tables.run()` or `table.init.run()` once.
|
||||
`tf.tables_initializer.run()` or `table.init.run()` once.
|
||||
|
||||
Sample Usages:
|
||||
|
||||
@ -1040,7 +1040,7 @@ def index_to_string_table_from_file(vocabulary_file,
|
||||
vocabulary_file="test.txt", default_value="UNKNOWN")
|
||||
values = table.lookup(indices)
|
||||
...
|
||||
tf.initialize_all_tables().run()
|
||||
tf.tables_initializer().run()
|
||||
|
||||
values.eval() ==> ["lake", "UNKNOWN"]
|
||||
```
|
||||
@ -1096,7 +1096,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None):
|
||||
(an out-of-vocabulary entry) is assigned the `default_value`
|
||||
|
||||
The underlying table must be initialized by calling
|
||||
`tf.initialize_all_tables.run()` or `table.init.run()` once.
|
||||
`tf.tables_initializer.run()` or `table.init.run()` once.
|
||||
|
||||
Elements in `mapping` cannot have duplicates, otherwise when executing the
|
||||
table initializer op, it will throw a `FailedPreconditionError`.
|
||||
@ -1110,7 +1110,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None):
|
||||
mapping_string, default_value="UNKNOWN")
|
||||
values = table.lookup(indices)
|
||||
...
|
||||
tf.initialize_all_tables().run()
|
||||
tf.tables_initializer().run()
|
||||
|
||||
values.eval() ==> ["lake", "UNKNOWN"]
|
||||
```
|
||||
@ -1159,7 +1159,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
|
||||
(an out-of-vocabulary entry) is assigned the `default_value`
|
||||
|
||||
The underlying table must be initialized by calling
|
||||
`tf.initialize_all_tables.run()` once.
|
||||
`tf.tables_initializer.run()` once.
|
||||
|
||||
For example:
|
||||
|
||||
@ -1169,7 +1169,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
|
||||
values = tf.contrib.lookup.index_to_string(
|
||||
indices, mapping=mapping_string, default_value="UNKNOWN")
|
||||
...
|
||||
tf.initialize_all_tables().run()
|
||||
tf.tables_initializer().run()
|
||||
|
||||
values.eval() ==> ["lake", "UNKNOWN"]
|
||||
```
|
||||
|
@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase):
|
||||
table3 = lookup_ops.HashTable(
|
||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual(3, table1.size().eval())
|
||||
self.assertAllEqual(3, table2.size().eval())
|
||||
self.assertAllEqual(3, table3.size().eval())
|
||||
@ -1148,7 +1148,7 @@ class StringToIndexTableFromFile(test.TestCase):
|
||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||
|
||||
def test_string_to_index_table_from_file_with_default_value(self):
|
||||
@ -1160,7 +1160,7 @@ class StringToIndexTableFromFile(test.TestCase):
|
||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, default_value), ids.eval())
|
||||
|
||||
def test_string_to_index_table_from_file_with_oov_buckets(self):
|
||||
@ -1172,7 +1172,7 @@ class StringToIndexTableFromFile(test.TestCase):
|
||||
constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual(
|
||||
(
|
||||
1, # From vocabulary file.
|
||||
@ -1195,7 +1195,7 @@ class StringToIndexTableFromFile(test.TestCase):
|
||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, -1, -1), ids.eval())
|
||||
self.assertEqual(2, table.size().eval())
|
||||
|
||||
@ -1222,7 +1222,7 @@ class StringToIndexTableFromFile(test.TestCase):
|
||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, -1), ids.eval())
|
||||
self.assertEqual(3, table.size().eval())
|
||||
|
||||
@ -1255,7 +1255,7 @@ class StringToIndexTableFromTensor(test.TestCase):
|
||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||
|
||||
def test_string_to_index_table_from_tensor_with_default_value(self):
|
||||
@ -1266,7 +1266,7 @@ class StringToIndexTableFromTensor(test.TestCase):
|
||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, default_value), ids.eval())
|
||||
|
||||
def test_string_to_index_table_from_tensor_with_only_oov_buckets(self):
|
||||
@ -1301,7 +1301,7 @@ class StringToIndexTest(test.TestCase):
|
||||
indices = lookup_ops.string_to_index(feats, mapping=mapping_strings)
|
||||
|
||||
self.assertRaises(errors_impl.OpError, indices.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
self.assertAllEqual((1, 2, -1), indices.eval())
|
||||
|
||||
@ -1312,7 +1312,7 @@ class StringToIndexTest(test.TestCase):
|
||||
indices = lookup_ops.string_to_index(feats, mapping=mapping_strings)
|
||||
|
||||
self.assertRaises(errors_impl.OpError,
|
||||
data_flow_ops.initialize_all_tables().run)
|
||||
data_flow_ops.tables_initializer().run)
|
||||
|
||||
def test_string_to_index_with_default_value(self):
|
||||
default_value = -42
|
||||
@ -1323,7 +1323,7 @@ class StringToIndexTest(test.TestCase):
|
||||
feats, mapping=mapping_strings, default_value=default_value)
|
||||
self.assertRaises(errors_impl.OpError, indices.eval)
|
||||
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, default_value), indices.eval())
|
||||
|
||||
|
||||
@ -1342,7 +1342,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
||||
vocabulary_file=vocabulary_file)
|
||||
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
|
||||
self.assertRaises(errors_impl.OpError, features.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
||||
features.eval())
|
||||
|
||||
@ -1354,7 +1354,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
||||
vocabulary_file=vocabulary_file, default_value=default_value)
|
||||
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
||||
self.assertRaises(errors_impl.OpError, features.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((b"salad", b"surgery", default_value),
|
||||
features.eval())
|
||||
|
||||
@ -1368,7 +1368,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
||||
default_value=default_value)
|
||||
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
||||
self.assertRaises(errors_impl.OpError, features.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((b"salad", default_value, default_value),
|
||||
features.eval())
|
||||
|
||||
@ -1380,7 +1380,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
||||
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, features.eval)
|
||||
init = data_flow_ops.initialize_all_tables()
|
||||
init = data_flow_ops.tables_initializer()
|
||||
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"Invalid vocab_size", init.run)
|
||||
|
||||
@ -1392,7 +1392,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
|
||||
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, features.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
|
||||
|
||||
|
||||
@ -1407,7 +1407,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
|
||||
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
||||
features = table.lookup(indices)
|
||||
self.assertRaises(errors_impl.OpError, features.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
||||
features.eval())
|
||||
@ -1419,7 +1419,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
|
||||
mapping=mapping_strings)
|
||||
indices = constant_op.constant([0, 1, 4], dtypes.int64)
|
||||
features = table.lookup(indices)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())
|
||||
|
||||
def test_index_to_string_with_default_value(self):
|
||||
@ -1432,7 +1432,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
|
||||
features = table.lookup(indices)
|
||||
self.assertRaises(errors_impl.OpError, features.eval)
|
||||
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((b"salad", b"surgery", default_value),
|
||||
features.eval())
|
||||
|
||||
@ -1446,7 +1446,7 @@ class IndexToStringTest(test.TestCase):
|
||||
feats = lookup_ops.index_to_string(indices, mapping=mapping_strings)
|
||||
|
||||
self.assertRaises(errors_impl.OpError, feats.eval)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
|
||||
feats.eval())
|
||||
@ -1456,11 +1456,11 @@ class IndexToStringTest(test.TestCase):
|
||||
mapping_strings = constant_op.constant(["hello", "hello"])
|
||||
indices = constant_op.constant([0, 1, 4], dtypes.int64)
|
||||
feats = lookup_ops.index_to_string(indices, mapping=mapping_strings)
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())
|
||||
|
||||
self.assertRaises(errors_impl.OpError,
|
||||
data_flow_ops.initialize_all_tables().run)
|
||||
data_flow_ops.tables_initializer().run)
|
||||
|
||||
def test_index_to_string_with_default_value(self):
|
||||
default_value = b"NONE"
|
||||
@ -1471,7 +1471,7 @@ class IndexToStringTest(test.TestCase):
|
||||
indices, mapping=mapping_strings, default_value=default_value)
|
||||
self.assertRaises(errors_impl.OpError, feats.eval)
|
||||
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
|
||||
|
||||
|
||||
@ -1615,7 +1615,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
||||
default_value,
|
||||
shared_name=shared_name)
|
||||
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
input_string = constant_op.constant(["brain", "salad", "tank"])
|
||||
|
||||
@ -1847,7 +1847,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
|
||||
hasher_spec=lookup_ops.StrongHashSpec((1, 2)),
|
||||
name="table2")
|
||||
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
input_string = constant_op.constant(
|
||||
["fruit", "brain", "salad", "surgery", "UNK"])
|
||||
@ -1933,7 +1933,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
|
||||
default_value2),
|
||||
oov_buckets)
|
||||
|
||||
data_flow_ops.initialize_all_tables().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
|
||||
input_string_1 = constant_op.constant(
|
||||
["brain", "salad", "surgery", "UNK"])
|
||||
|
@ -137,7 +137,7 @@ $(shell mkdir -p $(DEPDIR) >/dev/null)
|
||||
|
||||
# Settings for the target compiler.
|
||||
CXX := $(CC_PREFIX) gcc
|
||||
OPTFLAGS := -O2
|
||||
OPTFLAGS := -O2 -march=native
|
||||
CXXFLAGS := --std=c++11 -DIS_SLIM_BUILD -fno-exceptions -DNDEBUG $(OPTFLAGS)
|
||||
LDFLAGS := \
|
||||
-L/usr/local/lib
|
||||
|
@ -197,14 +197,13 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures,
|
||||
continue;
|
||||
}
|
||||
}
|
||||
SignatureDef signature_def;
|
||||
const Signature signature = it_named_signature.second;
|
||||
if (IsRegressionSignature(signature)) {
|
||||
(*meta_graph_def->mutable_signature_def())[key] =
|
||||
BuildRegressionSignatureDef(signature.regression_signature());
|
||||
} else if (IsClassificationSignature(signature)) {
|
||||
(*meta_graph_def->mutable_signature_def())[key] = signature_def;
|
||||
BuildClassificationSignatureDef(signature.classification_signature());
|
||||
(*meta_graph_def->mutable_signature_def())[key] =
|
||||
BuildClassificationSignatureDef(signature.classification_signature());
|
||||
} else {
|
||||
LOG(WARNING)
|
||||
<< "Named signature up-conversion to SignatureDef is only supported "
|
||||
|
@ -213,18 +213,32 @@ TEST(BundleShimTest, DefaultSignatureGeneric) {
|
||||
EXPECT_EQ(0, meta_graph_def.signature_def_size());
|
||||
}
|
||||
|
||||
// Helper function to validate that the SignatureDef found in the MetaGraphDef
|
||||
// with the provided key has the expected string representation.
|
||||
void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key,
|
||||
const string& expected_string_signature_def) {
|
||||
tensorflow::SignatureDef expected_signature;
|
||||
CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def,
|
||||
&expected_signature));
|
||||
auto iter = meta_graph_def.signature_def().find(key);
|
||||
ASSERT_TRUE(iter != meta_graph_def.signature_def().end());
|
||||
EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString());
|
||||
}
|
||||
|
||||
TEST(BundleShimTest, NamedRegressionSignatures) {
|
||||
Signatures signatures;
|
||||
|
||||
RegressionSignature* inputs_regression_signature =
|
||||
(*signatures.mutable_named_signatures())[kRegressInputs]
|
||||
RegressionSignature* foo_regression_signature =
|
||||
(*signatures.mutable_named_signatures())["foo"]
|
||||
.mutable_regression_signature();
|
||||
inputs_regression_signature->mutable_input()->set_tensor_name("foo-input");
|
||||
foo_regression_signature->mutable_input()->set_tensor_name("foo-input");
|
||||
foo_regression_signature->mutable_output()->set_tensor_name("foo-output");
|
||||
|
||||
RegressionSignature* outputs_regression_signature =
|
||||
(*signatures.mutable_named_signatures())[kRegressOutputs]
|
||||
RegressionSignature* bar_regression_signature =
|
||||
(*signatures.mutable_named_signatures())["bar"]
|
||||
.mutable_regression_signature();
|
||||
outputs_regression_signature->mutable_output()->set_tensor_name("foo-output");
|
||||
bar_regression_signature->mutable_input()->set_tensor_name("bar-input");
|
||||
bar_regression_signature->mutable_output()->set_tensor_name("bar-output");
|
||||
|
||||
MetaGraphDef meta_graph_def;
|
||||
(*meta_graph_def.mutable_collection_def())[kSignaturesKey]
|
||||
@ -232,7 +246,36 @@ TEST(BundleShimTest, NamedRegressionSignatures) {
|
||||
->add_value()
|
||||
->PackFrom(signatures);
|
||||
ConvertSignaturesToSignatureDefs(&meta_graph_def);
|
||||
EXPECT_EQ(2, meta_graph_def.signature_def_size());
|
||||
ASSERT_EQ(2, meta_graph_def.signature_def_size());
|
||||
|
||||
ValidateSignatureDef(meta_graph_def, "foo",
|
||||
"inputs { "
|
||||
" key: \"inputs\" "
|
||||
" value { "
|
||||
"name: \"foo-input\" "
|
||||
" } "
|
||||
"} "
|
||||
"outputs { "
|
||||
" key: \"outputs\" "
|
||||
" value { "
|
||||
" name: \"foo-output\" "
|
||||
" } "
|
||||
"} "
|
||||
"method_name: \"tensorflow/serving/regress\" ");
|
||||
ValidateSignatureDef(meta_graph_def, "bar",
|
||||
"inputs { "
|
||||
" key: \"inputs\" "
|
||||
" value { "
|
||||
"name: \"bar-input\" "
|
||||
" } "
|
||||
"} "
|
||||
"outputs { "
|
||||
" key: \"outputs\" "
|
||||
" value { "
|
||||
" name: \"bar-output\" "
|
||||
" } "
|
||||
"} "
|
||||
"method_name: \"tensorflow/serving/regress\" ");
|
||||
}
|
||||
|
||||
TEST(BundleShimTest, NamedClassificationSignatures) {
|
||||
@ -257,7 +300,36 @@ TEST(BundleShimTest, NamedClassificationSignatures) {
|
||||
->add_value()
|
||||
->PackFrom(signatures);
|
||||
ConvertSignaturesToSignatureDefs(&meta_graph_def);
|
||||
EXPECT_EQ(2, meta_graph_def.signature_def_size());
|
||||
ASSERT_EQ(2, meta_graph_def.signature_def_size());
|
||||
|
||||
ValidateSignatureDef(meta_graph_def, "foo",
|
||||
"inputs { "
|
||||
" key: \"inputs\" "
|
||||
" value { "
|
||||
"name: \"foo-input\" "
|
||||
" } "
|
||||
"} "
|
||||
"outputs { "
|
||||
" key: \"classes\" "
|
||||
" value { "
|
||||
" name: \"foo-classes\" "
|
||||
" } "
|
||||
"} "
|
||||
"method_name: \"tensorflow/serving/classify\" ");
|
||||
ValidateSignatureDef(meta_graph_def, "bar",
|
||||
"inputs { "
|
||||
" key: \"inputs\" "
|
||||
" value { "
|
||||
"name: \"bar-input\" "
|
||||
" } "
|
||||
"} "
|
||||
"outputs { "
|
||||
" key: \"scores\" "
|
||||
" value { "
|
||||
" name: \"bar-scores\" "
|
||||
" } "
|
||||
"} "
|
||||
"method_name: \"tensorflow/serving/classify\" ");
|
||||
}
|
||||
|
||||
// Checks the Predict SignatureDef created when the named signatures have
|
||||
|
@ -627,7 +627,7 @@ def train(train_op,
|
||||
init_feed_dict: A feed dictionary to use when executing the `init_op`.
|
||||
local_init_op: The local initialization operation. If left to its default
|
||||
value, then the session is initialized by calling
|
||||
`tf.local_variables_initializer()` and `tf.initialize_all_tables()`.
|
||||
`tf.local_variables_initializer()` and `tf.tables_initializer()`.
|
||||
init_fn: An optional callable to be executed after `init_op` is called. The
|
||||
callable must accept one argument, the session being initialized.
|
||||
ready_op: Operation to check if the model is ready to use. If left to its
|
||||
@ -697,7 +697,7 @@ def train(train_op,
|
||||
if local_init_op == _USE_DEFAULT:
|
||||
local_init_op = control_flow_ops.group(
|
||||
tf_variables.local_variables_initializer(),
|
||||
data_flow_ops.initialize_all_tables())
|
||||
data_flow_ops.tables_initializer())
|
||||
|
||||
if sync_optimizer is not None and isinstance(
|
||||
sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
|
||||
|
@ -115,8 +115,10 @@ def bucket(tensors,
|
||||
tensors: The list or dictionary of tensors, representing a single element,
|
||||
to bucket. Nested lists are not supported.
|
||||
which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
|
||||
batch_size: The new batch size pulled from the queue
|
||||
(python int or int32 scalar).
|
||||
batch_size: The new batch size pulled from the queue (all queues will have
|
||||
the same size). If a list is passed in then each bucket will have a
|
||||
different batch_size.
|
||||
(python int, int32 scalar or iterable of integers of length num_buckets).
|
||||
num_buckets: A python integer, the number of buckets.
|
||||
num_threads: An integer. The number of threads enqueuing `tensors`.
|
||||
capacity: An integer. The maximum number of minibatches in the top queue,
|
||||
@ -145,8 +147,17 @@ def bucket(tensors,
|
||||
|
||||
Raises:
|
||||
ValueError: If the `shapes` are not specified, and cannot be
|
||||
inferred from the elements of `tensors`.
|
||||
inferred from the elements of `tensors` or if batch_size is a sequence
|
||||
but it's length != num_buckets.
|
||||
"""
|
||||
batch_size_per_bucket = False
|
||||
if isinstance(batch_size, (list, tuple)):
|
||||
batch_size_per_bucket = True
|
||||
if len(batch_size) != num_buckets:
|
||||
raise ValueError(
|
||||
"If batch_size is a list it must have num_buckets elements")
|
||||
else:
|
||||
batch_size = [batch_size] * num_buckets
|
||||
tensor_list = _as_tensor_list(tensors)
|
||||
with ops.name_scope(name, "bucket", tensor_list) as name:
|
||||
tensor_list = _validate_bucket(tensor_list)
|
||||
@ -154,11 +165,12 @@ def bucket(tensors,
|
||||
tensor_list, enqueue_many=False, keep_input=constant_op.constant(True))
|
||||
|
||||
# Round-trip batch_size to a tensor, and possibly back
|
||||
batch_size = ops.convert_to_tensor(
|
||||
batch_size, dtype=dtypes.int32, name="batch_size")
|
||||
static_batch_size = tensor_util.constant_value(batch_size)
|
||||
batch_size = (static_batch_size if static_batch_size is not None else
|
||||
batch_size)
|
||||
for i, bucket_batch_size in enumerate(batch_size):
|
||||
bucket_batch_size = ops.convert_to_tensor(
|
||||
bucket_batch_size, dtype=dtypes.int32, name="batch_size")
|
||||
static_batch_size = tensor_util.constant_value(bucket_batch_size)
|
||||
batch_size[i] = (static_batch_size if static_batch_size is not None else
|
||||
bucket_batch_size)
|
||||
|
||||
types = _dtypes([tensor_list])
|
||||
shapes = _shapes([tensor_list], shapes, enqueue_many=False)
|
||||
@ -179,8 +191,9 @@ def bucket(tensors,
|
||||
shared_name=shared_name_i,
|
||||
name="bucket_queue_%d" % i))
|
||||
|
||||
maybe_static_batch_size = (None if allow_smaller_final_batch else
|
||||
static_batch_size)
|
||||
maybe_static_batch_size = (
|
||||
None if (allow_smaller_final_batch or batch_size_per_bucket)
|
||||
else static_batch_size)
|
||||
|
||||
bucket_shapes = [
|
||||
tensor_shape.vector(maybe_static_batch_size).concatenate(s)
|
||||
@ -229,9 +242,9 @@ def bucket(tensors,
|
||||
enqueues_to_top = [
|
||||
top_queue.enqueue(
|
||||
[constant_op.constant(i)] + which_dequeue(q)(
|
||||
batch_size, name="read_bucket_%d" % i),
|
||||
bs, name="read_bucket_%d" % i),
|
||||
name="enqueue_from_bucket_%d" % i)
|
||||
for i, q in enumerate(bucket_queues)
|
||||
for i, (q, bs) in enumerate(zip(bucket_queues, batch_size))
|
||||
]
|
||||
|
||||
for i, q in enumerate(bucket_queues):
|
||||
@ -284,8 +297,10 @@ def bucket_by_sequence_length(input_length,
|
||||
input_length: `int32` scalar `Tensor`, the sequence length of tensors.
|
||||
tensors: The list or dictionary of tensors, representing a single element,
|
||||
to bucket. Nested lists are not supported.
|
||||
batch_size: The new batch size pulled from the queue
|
||||
(python int or int32 scalar).
|
||||
batch_size: The new batch size pulled from the queue (all queues will have
|
||||
the same size). If a list is passed in then each bucket will have a
|
||||
different batch_size.
|
||||
(python int, int32 scalar or iterable of integers of length num_buckets).
|
||||
bucket_boundaries: int list, increasing non-negative numbers.
|
||||
The edges of the buckets to use when bucketing tensors. Two extra buckets
|
||||
are created, one for `input_length < bucket_boundaries[0]` and
|
||||
@ -317,7 +332,8 @@ def bucket_by_sequence_length(input_length,
|
||||
Raises:
|
||||
TypeError: if `bucket_boundaries` is not a list of python integers.
|
||||
ValueError: if `bucket_boundaries` is empty or contains non-increasing
|
||||
values.
|
||||
values or if batch_size is a list and it's length doesn't equal the number
|
||||
of buckets.
|
||||
"""
|
||||
tensor_list = _as_tensor_list(tensors)
|
||||
if not isinstance(bucket_boundaries, (list, tuple)):
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -139,6 +140,51 @@ class BucketTest(test.TestCase):
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])
|
||||
|
||||
def testBatchSizePerBucket(self):
|
||||
which_bucket = control_flow_ops.cond(self.scalar_int < 5,
|
||||
lambda: constant_op.constant(0),
|
||||
lambda: constant_op.constant(1))
|
||||
batch_sizes = [5, 10]
|
||||
bucketed_dynamic = bucket_ops.bucket(
|
||||
tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
|
||||
which_bucket=which_bucket,
|
||||
num_buckets=2,
|
||||
batch_size=batch_sizes,
|
||||
num_threads=1,
|
||||
dynamic_pad=True)
|
||||
# Check shape inference on bucketing outputs
|
||||
self.assertAllEqual(
|
||||
[[None], [None, None], [None, 3]],
|
||||
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
|
||||
with self.test_session() as sess:
|
||||
for v in range(15):
|
||||
self.enqueue_inputs(sess, {
|
||||
self.scalar_int_feed: v,
|
||||
self.unk_int64_feed: v * [v],
|
||||
self.vec3_str_feed: 3 * [str(v)]
|
||||
})
|
||||
self.start_queue_runners(sess)
|
||||
|
||||
# Get two minibatches (one with small values, one with large).
|
||||
bucketed_values_0 = sess.run(bucketed_dynamic)
|
||||
bucketed_values_1 = sess.run(bucketed_dynamic)
|
||||
|
||||
# Figure out which output has the small values
|
||||
if bucketed_values_0[0] < 5:
|
||||
bucketed_values_large, bucketed_values_small = (bucketed_values_1,
|
||||
bucketed_values_0)
|
||||
else:
|
||||
bucketed_values_small, bucketed_values_large = (bucketed_values_0,
|
||||
bucketed_values_1)
|
||||
|
||||
# Ensure bucket 0 was used for all minibatch entries.
|
||||
self.assertAllEqual(0, bucketed_values_small[0])
|
||||
self.assertAllEqual(1, bucketed_values_large[0])
|
||||
|
||||
# Check that the batch sizes differ per bucket
|
||||
self.assertEqual(5, len(bucketed_values_small[1][0]))
|
||||
self.assertEqual(10, len(bucketed_values_large[1][0]))
|
||||
|
||||
def testEvenOddBuckets(self):
|
||||
which_bucket = (self.scalar_int % 2)
|
||||
bucketed_dynamic = bucket_ops.bucket(
|
||||
|
@ -284,10 +284,10 @@ def add_gradients_summaries(grads_and_vars):
|
||||
else:
|
||||
grad_values = grad
|
||||
summaries.append(
|
||||
summary.histogram_summary(var.op.name + ':gradient', grad_values))
|
||||
summary.histogram(var.op.name + '_gradient', grad_values))
|
||||
summaries.append(
|
||||
summary.histogram_summary(var.op.name + ':gradient_norm',
|
||||
clip_ops.global_norm([grad_values])))
|
||||
summary.histogram(var.op.name + '_gradient_norm',
|
||||
clip_ops.global_norm([grad_values])))
|
||||
else:
|
||||
logging.info('Var %s has no gradient', var.op.name)
|
||||
|
||||
|
@ -64,6 +64,7 @@ load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_android",
|
||||
"if_ios",
|
||||
"if_not_mobile",
|
||||
"if_not_windows",
|
||||
"tf_copts",
|
||||
"tf_cc_test",
|
||||
@ -100,6 +101,8 @@ load(
|
||||
"tf_additional_test_deps",
|
||||
"tf_additional_test_srcs",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
"tf_additional_cloud_op_deps",
|
||||
"tf_additional_cloud_kernel_deps",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
@ -454,6 +457,16 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cloud_ops_op_lib",
|
||||
srcs = ["ops/cloud_ops.cc"],
|
||||
copts = tf_copts(),
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":framework"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops",
|
||||
visibility = ["//visibility:public"],
|
||||
@ -484,7 +497,7 @@ cc_library(
|
||||
":training_ops_op_lib",
|
||||
":user_ops_op_lib",
|
||||
":word2vec_ops",
|
||||
],
|
||||
] + tf_additional_cloud_op_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -606,7 +619,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:string",
|
||||
"//tensorflow/core/kernels:training_ops",
|
||||
"//tensorflow/core/kernels:word2vec_kernels",
|
||||
] + if_not_windows([
|
||||
] + tf_additional_cloud_kernel_deps() + if_not_windows([
|
||||
"//tensorflow/core/kernels:fact_op",
|
||||
"//tensorflow/core/kernels:array_not_windows",
|
||||
"//tensorflow/core/kernels:math_not_windows",
|
||||
|
@ -111,10 +111,10 @@ class Device : public DeviceBase {
|
||||
//
|
||||
// 'library' provides access to the function library which is shared
|
||||
// between all device partitions.
|
||||
// 'graphdef' supplies the partition of the graph assigned to this
|
||||
// 'graph' supplies the partition of the graph assigned to this
|
||||
// device.
|
||||
virtual Status MaybeRewriteGraph(const FunctionDefLibrary& /*library*/,
|
||||
GraphDef* /*graphdef*/) {
|
||||
std::unique_ptr<Graph>* /*graph*/) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -34,6 +34,7 @@ class DeviceAttributes;
|
||||
|
||||
class DeviceMgr {
|
||||
public:
|
||||
// Takes ownership of each device in 'devices'.
|
||||
// TODO(zhifengc): Other initialization information.
|
||||
explicit DeviceMgr(const std::vector<Device*>& devices);
|
||||
~DeviceMgr();
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/simple_placer.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -940,7 +941,7 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
GraphOptimizer optimizer(optimizer_opts);
|
||||
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
|
||||
const string& partition_name = iter->first;
|
||||
Graph* partition_graph = iter->second.get();
|
||||
std::unique_ptr<Graph>& partition_graph = iter->second;
|
||||
const int graph_def_version = partition_graph->versions().producer();
|
||||
|
||||
Device* device;
|
||||
@ -980,24 +981,23 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
};
|
||||
params.node_outputs_cb = node_outputs_callback_;
|
||||
|
||||
partition_graph = iter->second.release();
|
||||
optimizer.Optimize(lib, options_.env, device, &partition_graph);
|
||||
optimizer.Optimize(lib, options_.env, device, &iter->second);
|
||||
|
||||
// EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph
|
||||
if (run_state_args->debugger_state) {
|
||||
TF_RETURN_IF_ERROR(run_state_args->debugger_state->DecorateGraphForDebug(
|
||||
partition_graph, params.device));
|
||||
partition_graph.get(), params.device));
|
||||
}
|
||||
iter->second.reset(partition_graph);
|
||||
|
||||
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
|
||||
device->name(), partition_graph));
|
||||
device->name(),
|
||||
partition_graph.get()));
|
||||
// NewLocalExecutor takes ownership of partition_graph.
|
||||
item->graph = partition_graph;
|
||||
item->graph = partition_graph.get();
|
||||
item->executor = nullptr;
|
||||
Executor* executor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NewLocalExecutor(params, iter->second.release(), &executor));
|
||||
NewLocalExecutor(params, partition_graph.release(), &executor));
|
||||
item->executor.reset(executor);
|
||||
}
|
||||
|
||||
@ -1118,12 +1118,31 @@ Status DirectSession::CreateGraphs(
|
||||
}
|
||||
}
|
||||
|
||||
Status s;
|
||||
for (auto&& partition : partitions) {
|
||||
const string& partition_name = partition.first;
|
||||
for (const auto& partition : partitions) {
|
||||
std::unique_ptr<Graph> device_graph(
|
||||
new Graph(client_graph->flib_def.get()));
|
||||
GraphConstructorOptions device_opts;
|
||||
// There are internal operations (e.g., send/recv) that we now allow.
|
||||
device_opts.allow_internal_ops = true;
|
||||
device_opts.expect_device_spec = true;
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
|
||||
device_graph.get()));
|
||||
outputs->emplace(partition.first, std::move(device_graph));
|
||||
}
|
||||
|
||||
GraphDef* graph_def = &partition.second;
|
||||
VLOG(2) << "Created " << ProtoDebugString(*graph_def) << " for "
|
||||
GraphOptimizationPassOptions optimization_options;
|
||||
optimization_options.session_options = &options_;
|
||||
optimization_options.flib_def = client_graph->flib_def.get();
|
||||
optimization_options.partition_graphs = outputs;
|
||||
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
||||
OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
|
||||
|
||||
Status s;
|
||||
for (auto& partition : *outputs) {
|
||||
const string& partition_name = partition.first;
|
||||
std::unique_ptr<Graph>* graph = &partition.second;
|
||||
|
||||
VLOG(2) << "Created " << DebugString(graph->get()) << " for "
|
||||
<< partition_name;
|
||||
|
||||
// Give the device an opportunity to rewrite its subgraph.
|
||||
@ -1134,20 +1153,10 @@ Status DirectSession::CreateGraphs(
|
||||
// may be possible use cases where a device may want to modify
|
||||
// function definitions - in which case the library would need to be
|
||||
// replicated per device.
|
||||
s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph_def);
|
||||
s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph);
|
||||
if (!s.ok()) {
|
||||
break;
|
||||
}
|
||||
std::unique_ptr<Graph> device_graph(
|
||||
new Graph(client_graph->flib_def.get()));
|
||||
GraphConstructorOptions device_opts;
|
||||
// There are internal operations (e.g., send/recv) that we now
|
||||
// allow.
|
||||
device_opts.allow_internal_ops = true;
|
||||
device_opts.expect_device_spec = true;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToGraph(device_opts, *graph_def, device_graph.get()));
|
||||
outputs->emplace(partition_name, std::move(device_graph));
|
||||
}
|
||||
*flib_def = std::move(client_graph->flib_def);
|
||||
return s;
|
||||
|
@ -463,7 +463,7 @@ void DumpGraph(StringPiece label, const Graph* g) {
|
||||
}
|
||||
}
|
||||
|
||||
void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
|
||||
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
|
||||
OptimizerOptions opts;
|
||||
opts.set_do_common_subexpression_elimination(true);
|
||||
opts.set_do_function_inlining(true);
|
||||
@ -475,16 +475,12 @@ void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
|
||||
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
|
||||
const FunctionBody* fbody = GetFunctionBody(handle);
|
||||
CHECK_NOTNULL(fbody);
|
||||
Graph* g = new Graph(lib_def_);
|
||||
CopyGraph(*fbody->graph, g);
|
||||
std::unique_ptr<Graph> g(new Graph(lib_def_));
|
||||
CopyGraph(*fbody->graph, g.get());
|
||||
|
||||
optimizer_.Optimize(this, env(), device(), &g);
|
||||
auto s = EnsureMemoryTypes(DeviceType(device()->device_type()),
|
||||
device()->name(), g);
|
||||
if (!s.ok()) {
|
||||
delete g;
|
||||
return Status::OK();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
|
||||
device()->name(), g.get()));
|
||||
|
||||
// Creates an executor based on the g. This must be done without
|
||||
// holding mu_ because create_kernel_ calls back into the library.
|
||||
@ -495,11 +491,12 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
|
||||
params.delete_kernel = [](OpKernel* kernel) {
|
||||
DeleteNonCachedKernel(kernel);
|
||||
};
|
||||
Graph* graph = g.get();
|
||||
Executor* exec;
|
||||
TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &exec));
|
||||
TF_RETURN_IF_ERROR(NewLocalExecutor(params, g.release(), &exec));
|
||||
|
||||
*item = new Item;
|
||||
(*item)->graph = g;
|
||||
(*item)->graph = graph;
|
||||
(*item)->exec = exec;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
@ -126,7 +127,7 @@ void DumpGraph(StringPiece label, const Graph* g);
|
||||
// OptimizeGraph mutates **g extensively and replaces '*g' with a
|
||||
// complete copy. Therefore, the caller should not keep any references
|
||||
// to nodes *g.
|
||||
void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g);
|
||||
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);
|
||||
|
||||
// Convert the Graph of a function to a GraphDef.
|
||||
//
|
||||
|
@ -342,9 +342,9 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
|
||||
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
|
||||
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
|
||||
test::function::XTimes16()});
|
||||
Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}});
|
||||
std::unique_ptr<Graph> g(GetFuncBody("XTimes16", {{"T", DT_FLOAT}}));
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
ExpandInlineFunctions(lib_, g);
|
||||
ExpandInlineFunctions(lib_, g.get());
|
||||
OptimizeGraph(lib_, &g);
|
||||
const char* e0 = R"P(
|
||||
(n2:float) -> (n7:float) {
|
||||
@ -355,8 +355,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
|
||||
n7 = Mul[T=float](n6, n8)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e0, DebugString(g));
|
||||
delete g;
|
||||
EXPECT_EQ(e0, DebugString(g.get()));
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
|
||||
@ -380,15 +379,14 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
|
||||
// Return
|
||||
{{"o", "g:output"}});
|
||||
Init({test::function::Swap(), func});
|
||||
Graph* g = GetFuncBody("ManySwapsNodeDef", {});
|
||||
std::unique_ptr<Graph> g(GetFuncBody("ManySwapsNodeDef", {}));
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
OptimizeGraph(lib_, &g);
|
||||
const char* e0 = R"P(
|
||||
(n3:float, n2:float) -> (n3:float) {
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e0, DebugString(g));
|
||||
delete g;
|
||||
EXPECT_EQ(e0, DebugString(g.get()));
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
|
||||
@ -414,7 +412,7 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
|
||||
{{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
|
||||
{{"o", "o:z:0"}});
|
||||
Init({test::function::Swap(), func});
|
||||
Graph* g = GetFuncBody("ManySwapsFirst", {});
|
||||
std::unique_ptr<Graph> g(GetFuncBody("ManySwapsFirst", {}));
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
OptimizeGraph(lib_, &g);
|
||||
|
||||
@ -431,8 +429,7 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
|
||||
n6 = Add[T=float](n4, n5)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e0, DebugString(g));
|
||||
delete g;
|
||||
EXPECT_EQ(e0, DebugString(g.get()));
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
|
||||
@ -489,7 +486,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
|
||||
)P";
|
||||
EXPECT_EQ(e0, DebugString(f));
|
||||
delete f;
|
||||
auto g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}});
|
||||
std::unique_ptr<Graph> g(GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}));
|
||||
const char* e1 = R"P(
|
||||
(n4:float, n6:float) -> (n7:float) {
|
||||
n2 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
|
||||
@ -498,7 +495,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
|
||||
n7 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Mul[T=float]](n4, n3, n6)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e1, DebugString(g));
|
||||
EXPECT_EQ(e1, DebugString(g.get()));
|
||||
|
||||
OptimizeGraph(lib_, &g);
|
||||
const char* e2 = R"P(
|
||||
@ -512,9 +509,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
|
||||
n9 = Reshape[T=float, Tshape=int32](n8, n6)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e2, DebugString(g));
|
||||
|
||||
delete g;
|
||||
EXPECT_EQ(e2, DebugString(g.get()));
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) {
|
||||
@ -591,7 +586,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
|
||||
|
||||
Init({test, grad});
|
||||
|
||||
Graph* g = GetFuncBody("TestGrad", {});
|
||||
std::unique_ptr<Graph> g(GetFuncBody("TestGrad", {}));
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
const char* e0 = R"P(
|
||||
(n4:float, n3:float) -> (n8:float, n6:float) {
|
||||
@ -601,9 +596,9 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
|
||||
n8 = Identity[T=float](n5)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e0, DebugString(g));
|
||||
EXPECT_EQ(e0, DebugString(g.get()));
|
||||
|
||||
ExpandInlineFunctions(lib_, g);
|
||||
ExpandInlineFunctions(lib_, g.get());
|
||||
const char* e1 = R"P(
|
||||
(n4:float, n3:float) -> (n8:float, n6:float) {
|
||||
n10 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
|
||||
@ -625,7 +620,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
|
||||
n8 = Identity[T=float](n27)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e1, DebugString(g));
|
||||
EXPECT_EQ(e1, DebugString(g.get()));
|
||||
|
||||
OptimizeGraph(lib_, &g);
|
||||
const char* e2 = R"P(
|
||||
@ -652,8 +647,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
|
||||
n23 = Reshape[T=float, Tshape=int32](n22, n19)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e2, DebugString(g));
|
||||
delete g;
|
||||
EXPECT_EQ(e2, DebugString(g.get()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -21,6 +21,128 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/optimizer_cse.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Replaces occurrences of parallel_concat with the implementation based on
|
||||
// unsafe ops. Sets removed_any to true if any parallel_concats were removed;
|
||||
// leaves it untouched otherwise.
|
||||
// TODO(apassos) Use NodeBuilder.
|
||||
Status RemoveParallelConcat(bool* removed_any, Graph* g) {
|
||||
gtl::InlinedVector<Node*, 2> matches;
|
||||
for (Node* n : g->nodes()) {
|
||||
if (n->type_string() == "ParallelConcat") {
|
||||
matches.push_back(n);
|
||||
}
|
||||
}
|
||||
for (Node* n : matches) {
|
||||
AttrSlice n_attrs(n->def());
|
||||
auto make_node = [n, g, &n_attrs](string op) {
|
||||
NodeDef node;
|
||||
node.set_op(op);
|
||||
node.set_name(g->NewName(n->name()));
|
||||
node.set_device(n->def().device());
|
||||
string colo;
|
||||
if (GetNodeAttr(n_attrs, "_class", &colo).ok()) {
|
||||
AddNodeAttr("_class", colo, &node);
|
||||
}
|
||||
return node;
|
||||
};
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
|
||||
TensorShapeProto shape;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape));
|
||||
// Add the constant shape input to the start node.
|
||||
NodeDef shape_node_def = make_node("Const");
|
||||
AddNodeAttr("dtype", DT_INT32, &shape_node_def);
|
||||
TensorProto shape_tensor;
|
||||
shape_tensor.set_dtype(DT_INT32);
|
||||
shape_tensor.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size());
|
||||
for (int i = 0; i < shape.dim_size(); ++i) {
|
||||
shape_tensor.add_int_val(shape.dim(i).size());
|
||||
}
|
||||
AddNodeAttr("value", shape_tensor, &shape_node_def);
|
||||
Status status = Status::OK();
|
||||
Node* shape_node = g->AddNode(shape_node_def, &status);
|
||||
if (!status.ok()) return status;
|
||||
|
||||
// Add the start node
|
||||
NodeDef start_def = make_node("_ParallelConcatStart");
|
||||
AddNodeAttr("dtype", dtype, &start_def);
|
||||
AddNodeAttr("Tshape", DT_INT32, &start_def);
|
||||
AddNodeAttr("init", false, &start_def);
|
||||
start_def.add_input(shape_node_def.name());
|
||||
Node* start = g->AddNode(start_def, &status);
|
||||
if (!status.ok()) return status;
|
||||
// TODO(apassos): make the shape an attr of _ParallelStackBegin.
|
||||
g->AddEdge(shape_node, 0, start, 0);
|
||||
|
||||
// Add all the inplace_updates.
|
||||
std::vector<string> control_dependencies;
|
||||
std::vector<Node*> control_nodes;
|
||||
int i = 0;
|
||||
for (const Edge* input_edge : n->in_edges()) {
|
||||
if (input_edge->IsControlEdge()) {
|
||||
g->AddControlEdge(input_edge->src(), start);
|
||||
continue;
|
||||
}
|
||||
// Constant index for the update node.
|
||||
// TODO(apassos): make _ParallelStackUpdate take this as an attr.
|
||||
NodeDef update_idx_def = make_node("Const");
|
||||
AddNodeAttr("dtype", DT_INT64, &update_idx_def);
|
||||
TensorProto index_tensor;
|
||||
index_tensor.set_dtype(DT_INT64);
|
||||
index_tensor.mutable_tensor_shape()->add_dim()->set_size(1);
|
||||
index_tensor.add_int64_val(i);
|
||||
AddNodeAttr("value", index_tensor, &update_idx_def);
|
||||
Node* index = g->AddNode(update_idx_def, &status);
|
||||
if (!status.ok()) return status;
|
||||
|
||||
NodeDef update_def = make_node("_ParallelConcatUpdate");
|
||||
control_dependencies.push_back(update_def.name());
|
||||
AddNodeAttr("T", dtype, &update_def);
|
||||
AddNodeAttr("Tshape", DT_INT64, &update_def);
|
||||
update_def.add_input(start_def.name());
|
||||
update_def.add_input(update_idx_def.name());
|
||||
update_def.add_input(strings::StrCat(input_edge->src()->name(), ":",
|
||||
input_edge->src_output()));
|
||||
Node* update = g->AddNode(update_def, &status);
|
||||
if (!status.ok()) return status;
|
||||
g->AddEdge(start, 0, update, 0);
|
||||
g->AddEdge(index, 0, update, 1);
|
||||
g->AddEdge(input_edge->src(), input_edge->src_output(), update, 2);
|
||||
control_nodes.push_back(update);
|
||||
|
||||
++i;
|
||||
}
|
||||
|
||||
// Add the final identity.
|
||||
NodeDef identity_def = make_node("Identity");
|
||||
AddNodeAttr("T", dtype, &identity_def);
|
||||
identity_def.add_input(start_def.name());
|
||||
for (const string& s : control_dependencies) {
|
||||
identity_def.add_input(strings::StrCat("^", s));
|
||||
}
|
||||
Node* identity_node = g->AddNode(identity_def, &status);
|
||||
if (!status.ok()) return status;
|
||||
g->AddEdge(start, 0, identity_node, 0);
|
||||
for (Node* inp : control_nodes) {
|
||||
g->AddControlEdge(inp, identity_node);
|
||||
}
|
||||
|
||||
// Remove the node and redirect edges.
|
||||
for (auto* e : n->out_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
g->AddControlEdge(identity_node, e->dst());
|
||||
} else {
|
||||
g->AddEdge(identity_node, 0, e->dst(), e->dst_input());
|
||||
}
|
||||
}
|
||||
g->RemoveNode(n);
|
||||
*removed_any = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
|
||||
if (opts_.opt_level() >= OptimizerOptions::L1) {
|
||||
@ -32,8 +154,8 @@ GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
|
||||
GraphOptimizer::~GraphOptimizer() {}
|
||||
|
||||
void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
|
||||
Device* device, Graph** graph) {
|
||||
Graph* g = *graph;
|
||||
Device* device, std::unique_ptr<Graph>* graph) {
|
||||
Graph* g = graph->get();
|
||||
DumpGraph("Initial", g);
|
||||
|
||||
bool changed = true;
|
||||
@ -44,6 +166,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
|
||||
DumpGraph("RemoveListArrayConverter", g);
|
||||
changed = true;
|
||||
}
|
||||
auto s = RemoveParallelConcat(&changed, g);
|
||||
if (!s.ok()) {
|
||||
// TODO(apassos): figure out how to halt here.
|
||||
LOG(WARNING) << s;
|
||||
}
|
||||
if (opts_.do_function_inlining() && RemoveDeadNodes(g)) {
|
||||
DumpGraph("RemoveDeadNodes", g);
|
||||
changed = true;
|
||||
@ -78,11 +205,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
|
||||
if (!changed) break;
|
||||
}
|
||||
|
||||
Graph* copy = new Graph(g->op_registry());
|
||||
CopyGraph(*g, copy);
|
||||
delete g;
|
||||
*graph = copy;
|
||||
DumpGraph("ReCopy", *graph);
|
||||
std::unique_ptr<Graph> copy(new Graph(g->op_registry()));
|
||||
CopyGraph(*g, copy.get());
|
||||
graph->swap(copy);
|
||||
|
||||
DumpGraph("ReCopy", graph->get());
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -35,7 +35,7 @@ class GraphOptimizer {
|
||||
// optimizers so that they can respect constraints if any, that should be
|
||||
// respected.
|
||||
void Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device,
|
||||
Graph** graph);
|
||||
std::unique_ptr<Graph>* graph);
|
||||
|
||||
private:
|
||||
OptimizerOptions opts_;
|
||||
|
@ -39,8 +39,17 @@ struct GraphOptimizationPassOptions {
|
||||
const CostModel* cost_model = nullptr;
|
||||
|
||||
FunctionLibraryDefinition* flib_def = nullptr; // Not owned.
|
||||
|
||||
// The graph to optimize, for optimization passes that run before
|
||||
// partitioning. Null for post-partitioning passes.
|
||||
// An optimization pass may replace *graph with a new graph object.
|
||||
std::unique_ptr<Graph>* graph;
|
||||
std::unique_ptr<Graph>* graph = nullptr;
|
||||
|
||||
// Graphs for each partition, if running post-partitioning. Optimization
|
||||
// passes may alter the graphs, but must not add or remove partitions.
|
||||
// Null for pre-partitioning passes.
|
||||
std::unordered_map<string, std::unique_ptr<Graph>>* partition_graphs =
|
||||
nullptr;
|
||||
};
|
||||
|
||||
// Optimization passes are implemented by inheriting from
|
||||
@ -64,6 +73,7 @@ class OptimizationPassRegistry {
|
||||
PRE_PLACEMENT, // after cost model assignment, before placement.
|
||||
POST_PLACEMENT, // after placement.
|
||||
POST_REWRITE_FOR_EXEC, // after re-write using feed/fetch endpoints.
|
||||
POST_PARTITIONING, // after partitioning
|
||||
};
|
||||
|
||||
// Add an optimization pass to the registry.
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
@ -138,37 +139,50 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
||||
TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
|
||||
}
|
||||
|
||||
std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
|
||||
for (const auto& partition : partitions) {
|
||||
std::unique_ptr<Graph> device_graph(new Graph(item->lib_def));
|
||||
GraphConstructorOptions device_opts;
|
||||
// There are internal operations (e.g., send/recv) that we now allow.
|
||||
device_opts.allow_internal_ops = true;
|
||||
device_opts.expect_device_spec = true;
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
|
||||
device_graph.get()));
|
||||
partition_graphs.emplace(partition.first, std::move(device_graph));
|
||||
}
|
||||
|
||||
GraphOptimizationPassOptions optimization_options;
|
||||
optimization_options.flib_def = item->lib_def;
|
||||
optimization_options.partition_graphs = &partition_graphs;
|
||||
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
||||
OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
|
||||
|
||||
LocalExecutorParams params;
|
||||
|
||||
Status s;
|
||||
item->units.reserve(partitions.size());
|
||||
item->graph_mgr = this;
|
||||
const auto& optimizer_opts = graph_options.optimizer_options();
|
||||
GraphOptimizer optimizer(optimizer_opts);
|
||||
for (auto&& p : partitions) {
|
||||
for (auto& p : partition_graphs) {
|
||||
const string& device_name = p.first;
|
||||
GraphDef* def = &p.second;
|
||||
std::unique_ptr<Graph>& subgraph = p.second;
|
||||
item->units.resize(item->units.size() + 1);
|
||||
ExecutionUnit* unit = &(item->units.back());
|
||||
|
||||
// Find the device.
|
||||
s = worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
|
||||
Status s =
|
||||
worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
|
||||
if (!s.ok()) {
|
||||
// Remove the empty unit from the item as the item destructor wants all
|
||||
// units to have valid devices.
|
||||
item->units.pop_back();
|
||||
break;
|
||||
return s;
|
||||
}
|
||||
|
||||
// Construct the subgraph.
|
||||
Graph* subgraph = new Graph(item->lib_def);
|
||||
// Give the device an opportunity to rewrite its subgraph.
|
||||
unit->device->MaybeRewriteGraph(gdef.library(), def);
|
||||
s = ConvertGraphDefToGraph(opts, *def, subgraph);
|
||||
if (!s.ok()) {
|
||||
delete subgraph;
|
||||
break;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
unit->device->MaybeRewriteGraph(gdef.library(), &subgraph));
|
||||
|
||||
// Top-level nodes in the graph uses the op segment to cache
|
||||
// kernels. Therefore, as long as the executor is alive, we need
|
||||
// to ensure the kernels cached for the session are alive.
|
||||
@ -178,7 +192,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
||||
// Function library runtime.
|
||||
unit->lib = NewFunctionLibraryRuntime(
|
||||
worker_env_->device_mgr, worker_env_->env, unit->device,
|
||||
def->versions().producer(), item->lib_def,
|
||||
subgraph->versions().producer(), item->lib_def,
|
||||
graph_options.optimizer_options());
|
||||
|
||||
// Construct the root executor for the subgraph.
|
||||
@ -207,23 +221,18 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
||||
};
|
||||
|
||||
optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph);
|
||||
s = EnsureMemoryTypes(DeviceType(unit->device->device_type()),
|
||||
unit->device->name(), subgraph);
|
||||
if (!s.ok()) {
|
||||
delete subgraph;
|
||||
break;
|
||||
}
|
||||
s = NewLocalExecutor(params, subgraph, &unit->root);
|
||||
if (!s.ok()) {
|
||||
break;
|
||||
}
|
||||
unit->graph = subgraph;
|
||||
TF_RETURN_IF_ERROR(
|
||||
EnsureMemoryTypes(DeviceType(unit->device->device_type()),
|
||||
unit->device->name(), subgraph.get()));
|
||||
unit->graph = subgraph.get();
|
||||
unit->build_cost_model = graph_options.build_cost_model();
|
||||
if (unit->build_cost_model > 0) {
|
||||
skip_cost_models_ = false;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
NewLocalExecutor(params, subgraph.release(), &unit->root));
|
||||
}
|
||||
return s;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphMgr::Register(const string& session, const GraphDef& gdef,
|
||||
|
@ -85,7 +85,7 @@ Status LocalMaster::PartialRunSetup(CallOptions* call_options,
|
||||
|
||||
Status LocalMaster::RunStep(CallOptions* call_options,
|
||||
RunStepRequestWrapper* request,
|
||||
RunStepResponse* response) {
|
||||
MutableRunStepResponseWrapper* response) {
|
||||
Notification n;
|
||||
Status ret;
|
||||
master_impl_->RunStep(call_options, request, response,
|
||||
@ -101,6 +101,10 @@ MutableRunStepRequestWrapper* LocalMaster::CreateRunStepRequest() {
|
||||
return new InMemoryRunStepRequest;
|
||||
}
|
||||
|
||||
MutableRunStepResponseWrapper* LocalMaster::CreateRunStepResponse() {
|
||||
return new InMemoryRunStepResponse;
|
||||
}
|
||||
|
||||
Status LocalMaster::CloseSession(CallOptions* call_options,
|
||||
const CloseSessionRequest* request,
|
||||
CloseSessionResponse* response) {
|
||||
|
@ -53,10 +53,12 @@ class LocalMaster : public MasterInterface {
|
||||
PartialRunSetupResponse* response) override;
|
||||
|
||||
Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
|
||||
RunStepResponse* response) override;
|
||||
MutableRunStepResponseWrapper* response) override;
|
||||
|
||||
MutableRunStepRequestWrapper* CreateRunStepRequest() override;
|
||||
|
||||
MutableRunStepResponseWrapper* CreateRunStepResponse() override;
|
||||
|
||||
Status CloseSession(CallOptions* call_options,
|
||||
const CloseSessionRequest* request,
|
||||
CloseSessionResponse* response) override;
|
||||
|
@ -356,7 +356,7 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req,
|
||||
}
|
||||
|
||||
void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
|
||||
RunStepResponse* resp, MyClosure done) {
|
||||
MutableRunStepResponseWrapper* resp, MyClosure done) {
|
||||
mu_.lock();
|
||||
uint64 start_time = env_->env->NowMicros();
|
||||
MasterSession* session = gtl::FindPtrOrNull(sessions_, req->session_handle());
|
||||
|
@ -50,7 +50,7 @@ class Master {
|
||||
PartialRunSetupResponse* resp, MyClosure done);
|
||||
|
||||
void RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
|
||||
RunStepResponse* resp, MyClosure done);
|
||||
MutableRunStepResponseWrapper* resp, MyClosure done);
|
||||
|
||||
void CloseSession(const CloseSessionRequest* req, CloseSessionResponse* resp,
|
||||
MyClosure done);
|
||||
|
@ -24,10 +24,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Pure virtual interface for communicating with the TensorFlow Master service.
|
||||
// Abstract interface for communicating with the TensorFlow Master service.
|
||||
//
|
||||
// This interface is intended to support in-process master
|
||||
// implementations that do not require an RPC roundtrip.
|
||||
// This interface supports both RPC-based master implementations, and
|
||||
// in-process master implementations that do not require an RPC
|
||||
// roundtrip.
|
||||
class MasterInterface {
|
||||
public:
|
||||
virtual ~MasterInterface() {}
|
||||
@ -47,20 +48,36 @@ class MasterInterface {
|
||||
|
||||
virtual Status RunStep(CallOptions* call_options,
|
||||
RunStepRequestWrapper* request,
|
||||
RunStepResponse* response) = 0;
|
||||
MutableRunStepResponseWrapper* response) = 0;
|
||||
|
||||
virtual Status RunStep(CallOptions* call_options,
|
||||
const RunStepRequest* request,
|
||||
RunStepResponse* response) {
|
||||
std::unique_ptr<RunStepRequestWrapper> wrapped_request(
|
||||
new ProtoRunStepRequest(request));
|
||||
return RunStep(call_options, wrapped_request.get(), response);
|
||||
std::unique_ptr<MutableRunStepResponseWrapper> wrapped_response(
|
||||
new NonOwnedProtoRunStepResponse(response));
|
||||
return RunStep(call_options, wrapped_request.get(), wrapped_response.get());
|
||||
}
|
||||
|
||||
// Returns a request object for use in calls to
|
||||
// `RunStep()`. Ownership is transferred to the caller.
|
||||
//
|
||||
// The message returned from this method must only be used in a
|
||||
// `RunStep()` call on the same `MasterInterface` instance.
|
||||
virtual MutableRunStepRequestWrapper* CreateRunStepRequest() {
|
||||
return new MutableProtoRunStepRequest;
|
||||
}
|
||||
|
||||
// Returns a response object for use in calls to
|
||||
// `RunStep()`. Ownership is transferred to the caller.
|
||||
//
|
||||
// The message returned from this method must only be used in a
|
||||
// `RunStep()` call on the same `MasterInterface` instance.
|
||||
virtual MutableRunStepResponseWrapper* CreateRunStepResponse() {
|
||||
return new OwnedProtoRunStepResponse;
|
||||
}
|
||||
|
||||
virtual Status CloseSession(CallOptions* call_options,
|
||||
const CloseSessionRequest* request,
|
||||
CloseSessionResponse* response) = 0;
|
||||
@ -71,6 +88,15 @@ class MasterInterface {
|
||||
|
||||
virtual Status Reset(CallOptions* call_options, const ResetRequest* request,
|
||||
ResetResponse* response) = 0;
|
||||
|
||||
protected:
|
||||
// NOTE: This should only be called by implementations of this
|
||||
// interface whose CreateRunStepResponse() method returns a
|
||||
// proto-based wrappers for the RunStepResponse message.
|
||||
RunStepResponse* get_proto_from_wrapper(
|
||||
MutableRunStepResponseWrapper* wrapper) {
|
||||
return wrapper->get_proto();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -182,7 +182,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
int64 execution_count,
|
||||
SimpleGraphExecutionState* execution_state,
|
||||
PerStepState* pss, CallOptions* opts,
|
||||
const RunStepRequestWrapper& req, RunStepResponse* resp,
|
||||
const RunStepRequestWrapper& req,
|
||||
MutableRunStepResponseWrapper* resp,
|
||||
CancellationManager* cm, const bool is_last_partial_run);
|
||||
|
||||
// Calls workers to cleanup states for the step "step_id". Calls
|
||||
@ -193,7 +194,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
void ProcessStats(int64 step_id, PerStepState* pss,
|
||||
SimpleGraphExecutionState* execution_state,
|
||||
ProfileHandler* ph, const RunOptions& options,
|
||||
RunStepResponse* resp);
|
||||
RunMetadata* resp);
|
||||
void ProcessDeviceStats(ProfileHandler* ph,
|
||||
const SimpleGraphExecutionState* execution_state,
|
||||
const DeviceStepStats& ds, bool is_rpc);
|
||||
@ -464,7 +465,7 @@ class RunManyGraphs {
|
||||
struct Call {
|
||||
CallOptions opts;
|
||||
std::unique_ptr<MutableRunGraphRequestWrapper> req;
|
||||
RunGraphResponse resp;
|
||||
std::unique_ptr<MutableRunGraphResponseWrapper> resp;
|
||||
};
|
||||
Call* get(int index) { return &calls_[index]; }
|
||||
|
||||
@ -513,7 +514,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
const MasterEnv* env, int64 step_id, int64 execution_count,
|
||||
SimpleGraphExecutionState* execution_state, PerStepState* pss,
|
||||
CallOptions* call_opts, const RunStepRequestWrapper& req,
|
||||
RunStepResponse* resp, CancellationManager* cm,
|
||||
MutableRunStepResponseWrapper* resp, CancellationManager* cm,
|
||||
const bool is_last_partial_run) {
|
||||
VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
|
||||
<< execution_count;
|
||||
@ -550,6 +551,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
const Part& part = partitions_[i];
|
||||
RunManyGraphs::Call* c = calls.get(i);
|
||||
c->req.reset(part.worker->CreateRunGraphRequest());
|
||||
c->resp.reset(part.worker->CreateRunGraphResponse());
|
||||
if (is_partial_) {
|
||||
c->req->set_is_partial(is_partial_);
|
||||
c->req->set_is_last_partial_run(is_last_partial_run);
|
||||
@ -614,7 +616,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
RunManyGraphs::Call* call = calls.get(i);
|
||||
TRACEPRINTF("Partition %d %s", i, part.name.c_str());
|
||||
part.worker->RunGraphAsync(
|
||||
&call->opts, call->req.get(), &call->resp,
|
||||
&call->opts, call->req.get(), call->resp.get(),
|
||||
std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
|
||||
}
|
||||
|
||||
@ -639,29 +641,28 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
if (status.ok()) {
|
||||
for (int i = 0; i < num; ++i) {
|
||||
const Part& part = partitions_[i];
|
||||
for (auto& recv : *(calls.get(i)->resp.mutable_recv())) {
|
||||
auto* ret = resp->add_tensor();
|
||||
auto iter = part.key_fetch.find(recv.name());
|
||||
for (size_t j = 0; j < calls.get(i)->resp->num_recvs(); ++j) {
|
||||
auto iter = part.key_fetch.find(calls.get(i)->resp->recv_key(j));
|
||||
if (iter == part.key_fetch.end()) {
|
||||
status.Update(
|
||||
errors::Internal("Unexpected fetch key: ", recv.name()));
|
||||
status.Update(errors::Internal("Unexpected fetch key: ",
|
||||
calls.get(i)->resp->recv_key(j)));
|
||||
break;
|
||||
}
|
||||
const string& fetch = iter->second;
|
||||
ret->set_name(fetch);
|
||||
if (!CopyIfNeeded(recv.mutable_tensor(), ret->mutable_tensor())) {
|
||||
status.Update(
|
||||
errors::Internal("Unexpected unparseable tensor: ", recv.name()));
|
||||
status.Update(resp->AddTensorFromRunGraphResponse(
|
||||
fetch, calls.get(i)->resp.get(), j));
|
||||
if (!status.ok()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pss->collect_timeline && calls.get(i)->resp.has_step_stats()) {
|
||||
pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats());
|
||||
if (pss->collect_timeline) {
|
||||
pss->step_stats[i].Swap(calls.get(i)->resp->mutable_step_stats());
|
||||
}
|
||||
if (pss->collect_costs && calls.get(i)->resp.has_cost_graph()) {
|
||||
for (int j = 0; j < calls.get(i)->resp.cost_graph().node_size(); ++j) {
|
||||
if (pss->collect_costs) {
|
||||
CostGraphDef* cost_graph = calls.get(i)->resp->mutable_cost_graph();
|
||||
for (int j = 0; j < cost_graph->node_size(); ++j) {
|
||||
resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
|
||||
calls.get(i)->resp.mutable_cost_graph()->mutable_node(j));
|
||||
cost_graph->mutable_node(j));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -739,7 +740,7 @@ void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
|
||||
void MasterSession::ReffedClientGraph::ProcessStats(
|
||||
int64 step_id, PerStepState* pss,
|
||||
SimpleGraphExecutionState* execution_state, ProfileHandler* ph,
|
||||
const RunOptions& options, RunStepResponse* resp) {
|
||||
const RunOptions& options, RunMetadata* resp) {
|
||||
if (!pss->collect_costs && !pss->collect_timeline) return;
|
||||
|
||||
// Out-of-band logging data is collected now, during post-processing.
|
||||
@ -775,7 +776,7 @@ void MasterSession::ReffedClientGraph::ProcessStats(
|
||||
// Copy the stats back, but only for on-demand profiling to avoid slowing
|
||||
// down calls that trigger the automatic profiling.
|
||||
if (options.trace_level() == RunOptions::FULL_TRACE) {
|
||||
resp->mutable_metadata()->mutable_step_stats()->Swap(&step_stats_proto);
|
||||
resp->mutable_step_stats()->Swap(&step_stats_proto);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1171,7 +1172,7 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
|
||||
}
|
||||
|
||||
Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
|
||||
RunStepResponse* resp) {
|
||||
MutableRunStepResponseWrapper* resp) {
|
||||
UpdateLastAccessTime();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
@ -1239,7 +1240,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
||||
|
||||
Status MasterSession::DoPartialRun(CallOptions* opts,
|
||||
const RunStepRequestWrapper& req,
|
||||
RunStepResponse* resp) {
|
||||
MutableRunStepResponseWrapper* resp) {
|
||||
const string& prun_handle = req.partial_run_handle();
|
||||
RunState* run_state = nullptr;
|
||||
{
|
||||
@ -1328,7 +1329,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
|
||||
rcg->Ref();
|
||||
rcg->ProcessStats(run_state->step_id, &run_state->pss,
|
||||
execution_state_.get(), run_state->ph.get(),
|
||||
req.options(), resp);
|
||||
req.options(), resp->mutable_metadata());
|
||||
rcg->CleanupPartitionsAsync(
|
||||
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
|
||||
if (!s.ok()) {
|
||||
@ -1342,9 +1343,9 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
|
||||
return s;
|
||||
}
|
||||
|
||||
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
const RunStepRequestWrapper& req,
|
||||
RunStepResponse* resp) {
|
||||
Status MasterSession::DoRunWithLocalExecution(
|
||||
CallOptions* opts, const RunStepRequestWrapper& req,
|
||||
MutableRunStepResponseWrapper* resp) {
|
||||
VLOG(2) << "DoRunWithLocalExecution "
|
||||
<< "req: " << req.DebugString();
|
||||
PerStepState pss;
|
||||
@ -1395,7 +1396,7 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
// Schedule post-processing and cleanup to be done asynchronously.
|
||||
rcg->Ref();
|
||||
rcg->ProcessStats(step_id, &pss, execution_state_.get(), ph.get(),
|
||||
req.options(), resp);
|
||||
req.options(), resp->mutable_metadata());
|
||||
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Cleanup partition error: " << s;
|
||||
|
@ -80,7 +80,7 @@ class MasterSession : public core::RefCounted {
|
||||
|
||||
// Run one step.
|
||||
Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
|
||||
RunStepResponse* resp);
|
||||
MutableRunStepResponseWrapper* resp);
|
||||
|
||||
// Close this session and delete "*this". Returns OK if all known
|
||||
// states are cleanup successfully.
|
||||
@ -177,9 +177,9 @@ class MasterSession : public core::RefCounted {
|
||||
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status DoRunWithLocalExecution(CallOptions* opts,
|
||||
const RunStepRequestWrapper& req,
|
||||
RunStepResponse* resp);
|
||||
MutableRunStepResponseWrapper* resp);
|
||||
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
|
||||
RunStepResponse* resp);
|
||||
MutableRunStepResponseWrapper* resp);
|
||||
void UpdateLastAccessTime();
|
||||
|
||||
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
|
||||
|
@ -17,6 +17,22 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
|
||||
Tensor* out_tensor) {
|
||||
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
*out_tensor = parsed;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const string& InMemoryRunStepRequest::session_handle() const {
|
||||
return session_handle_;
|
||||
}
|
||||
@ -38,13 +54,14 @@ const string& InMemoryRunStepRequest::feed_name(size_t i) const {
|
||||
return feeds_[i].first;
|
||||
}
|
||||
|
||||
Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* tensor) const {
|
||||
*tensor = feeds_[i].second;
|
||||
Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
|
||||
*out_tensor = feeds_[i].second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InMemoryRunStepRequest::FeedValue(size_t i, TensorProto* tensor) const {
|
||||
feeds_[i].second.AsProtoTensorContent(tensor);
|
||||
Status InMemoryRunStepRequest::FeedValue(size_t i,
|
||||
TensorProto* out_tensor) const {
|
||||
feeds_[i].second.AsProtoTensorContent(out_tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -117,21 +134,18 @@ size_t MutableProtoRunStepRequest::num_feeds() const {
|
||||
const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
|
||||
return request_.feed(i).name();
|
||||
}
|
||||
Status MutableProtoRunStepRequest::FeedValue(size_t i, Tensor* tensor) const {
|
||||
const TensorProto& tensor_proto = request_.feed(i).tensor();
|
||||
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
*tensor = parsed;
|
||||
return Status::OK();
|
||||
}
|
||||
Status MutableProtoRunStepRequest::FeedValue(size_t i,
|
||||
Tensor* out_tensor) const {
|
||||
if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
|
||||
}
|
||||
|
||||
Status MutableProtoRunStepRequest::FeedValue(size_t i,
|
||||
TensorProto* tensor) const {
|
||||
*tensor = request_.feed(i).tensor();
|
||||
TensorProto* out_tensor) const {
|
||||
*out_tensor = request_.feed(i).tensor();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -199,20 +213,16 @@ const string& ProtoRunStepRequest::feed_name(size_t i) const {
|
||||
return request_->feed(i).name();
|
||||
}
|
||||
|
||||
Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* tensor) const {
|
||||
const TensorProto& tensor_proto = request_->feed(i).tensor();
|
||||
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
*tensor = parsed;
|
||||
return Status::OK();
|
||||
}
|
||||
Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
|
||||
if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
|
||||
}
|
||||
|
||||
Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* tensor) const {
|
||||
*tensor = request_->feed(i).tensor();
|
||||
Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
|
||||
*out_tensor = request_->feed(i).tensor();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -361,15 +371,11 @@ const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
|
||||
|
||||
Status MutableProtoRunGraphRequest::SendValue(size_t i,
|
||||
Tensor* out_tensor) const {
|
||||
const TensorProto& tensor_proto = request_.send(i).tensor();
|
||||
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
*out_tensor = parsed;
|
||||
return Status::OK();
|
||||
}
|
||||
if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
|
||||
}
|
||||
|
||||
Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
|
||||
@ -434,15 +440,11 @@ const string& ProtoRunGraphRequest::send_key(size_t i) const {
|
||||
}
|
||||
|
||||
Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
|
||||
const TensorProto& tensor_proto = request_->send(i).tensor();
|
||||
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
*out_tensor = parsed;
|
||||
return Status::OK();
|
||||
}
|
||||
if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
|
||||
}
|
||||
|
||||
size_t ProtoRunGraphRequest::num_recvs() const {
|
||||
@ -463,4 +465,228 @@ const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
|
||||
return *request_;
|
||||
}
|
||||
|
||||
size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }
|
||||
|
||||
const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
|
||||
return recvs_[i].first;
|
||||
}
|
||||
|
||||
Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
|
||||
recvs_[i].second.AsProtoTensorContent(out_tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
|
||||
*out_tensor = recvs_[i].second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
|
||||
recvs_.emplace_back(key, value);
|
||||
}
|
||||
|
||||
StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
|
||||
return &step_stats_;
|
||||
}
|
||||
|
||||
CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
|
||||
return &cost_graph_;
|
||||
}
|
||||
|
||||
RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
|
||||
LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
|
||||
}
|
||||
|
||||
size_t OwnedProtoRunGraphResponse::num_recvs() const {
|
||||
return response_.recv_size();
|
||||
}
|
||||
|
||||
const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
|
||||
return response_.recv(i).name();
|
||||
}
|
||||
|
||||
Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
|
||||
TensorProto* out_tensor) {
|
||||
out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
|
||||
if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
void OwnedProtoRunGraphResponse::AddRecv(const string& key,
|
||||
const Tensor& value) {
|
||||
NamedTensorProto* recv = response_.add_recv();
|
||||
recv->set_name(key);
|
||||
TensorProto* value_proto = recv->mutable_tensor();
|
||||
value.AsProtoTensorContent(value_proto);
|
||||
}
|
||||
|
||||
StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
|
||||
return response_.mutable_step_stats();
|
||||
}
|
||||
|
||||
CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
|
||||
return response_.mutable_cost_graph();
|
||||
}
|
||||
|
||||
RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
|
||||
|
||||
NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
|
||||
RunGraphResponse* response)
|
||||
: response_(response) {}
|
||||
|
||||
size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
|
||||
return response_->recv_size();
|
||||
}
|
||||
|
||||
const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
|
||||
return response_->recv(i).name();
|
||||
}
|
||||
|
||||
Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
|
||||
TensorProto* out_tensor) {
|
||||
out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
|
||||
if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
|
||||
const Tensor& value) {
|
||||
NamedTensorProto* recv = response_->add_recv();
|
||||
recv->set_name(key);
|
||||
TensorProto* value_proto = recv->mutable_tensor();
|
||||
value.AsProtoTensorContent(value_proto);
|
||||
}
|
||||
|
||||
StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
|
||||
return response_->mutable_step_stats();
|
||||
}
|
||||
|
||||
CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
|
||||
return response_->mutable_cost_graph();
|
||||
}
|
||||
|
||||
RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
|
||||
return response_;
|
||||
}
|
||||
|
||||
MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
|
||||
|
||||
size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }
|
||||
|
||||
const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
|
||||
return tensors_[i].first;
|
||||
}
|
||||
|
||||
Status InMemoryRunStepResponse::TensorValue(size_t i,
|
||||
Tensor* out_tensor) const {
|
||||
*out_tensor = tensors_[i].second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const RunMetadata& InMemoryRunStepResponse::metadata() const {
|
||||
return metadata_;
|
||||
}
|
||||
|
||||
Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
|
||||
const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
|
||||
Tensor tensor;
|
||||
TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
|
||||
tensors_.emplace_back(name, tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }
|
||||
|
||||
RunStepResponse* InMemoryRunStepResponse::get_proto() {
|
||||
LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
|
||||
}
|
||||
|
||||
size_t OwnedProtoRunStepResponse::num_tensors() const {
|
||||
return response_.tensor_size();
|
||||
}
|
||||
|
||||
const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
|
||||
return response_.tensor(i).name();
|
||||
}
|
||||
|
||||
Status OwnedProtoRunStepResponse::TensorValue(size_t i,
|
||||
Tensor* out_tensor) const {
|
||||
if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
|
||||
return response_.metadata();
|
||||
}
|
||||
|
||||
Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
|
||||
const string& name, MutableRunGraphResponseWrapper* run_graph_response,
|
||||
size_t i) {
|
||||
NamedTensorProto* response_tensor = response_.add_tensor();
|
||||
response_tensor->set_name(name);
|
||||
return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
|
||||
}
|
||||
|
||||
RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
|
||||
return response_.mutable_metadata();
|
||||
}
|
||||
|
||||
RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
|
||||
|
||||
NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
|
||||
RunStepResponse* response)
|
||||
: response_(response) {}
|
||||
|
||||
size_t NonOwnedProtoRunStepResponse::num_tensors() const {
|
||||
return response_->tensor_size();
|
||||
}
|
||||
|
||||
const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
|
||||
return response_->tensor(i).name();
|
||||
}
|
||||
|
||||
Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
|
||||
Tensor* out_tensor) const {
|
||||
if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
|
||||
return response_->metadata();
|
||||
}
|
||||
|
||||
Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
|
||||
const string& name, MutableRunGraphResponseWrapper* run_graph_response,
|
||||
size_t i) {
|
||||
NamedTensorProto* response_tensor = response_->add_tensor();
|
||||
response_tensor->set_name(name);
|
||||
return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
|
||||
}
|
||||
|
||||
RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
|
||||
return response_->mutable_metadata();
|
||||
}
|
||||
|
||||
RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -58,8 +58,8 @@ class RunStepRequestWrapper {
|
||||
virtual const string& feed_name(size_t i) const = 0;
|
||||
|
||||
// Stores the content of the feed value at index `i` in `tensor`.
|
||||
virtual Status FeedValue(size_t i, Tensor* tensor) const = 0;
|
||||
virtual Status FeedValue(size_t i, TensorProto* tensor) const = 0;
|
||||
virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0;
|
||||
virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0;
|
||||
|
||||
// Fetches. A list of tensor names. The caller expects a tensor to
|
||||
// be returned for each fetch[i] (see RunStepResponse.tensor). The
|
||||
@ -104,8 +104,8 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
|
||||
const string& partial_run_handle() const override;
|
||||
size_t num_feeds() const override;
|
||||
const string& feed_name(size_t i) const override;
|
||||
Status FeedValue(size_t i, Tensor* tensor) const override;
|
||||
Status FeedValue(size_t i, TensorProto* tensor) const override;
|
||||
Status FeedValue(size_t i, Tensor* out_tensor) const override;
|
||||
Status FeedValue(size_t i, TensorProto* out_tensor) const override;
|
||||
size_t num_fetches() const override;
|
||||
const string& fetch_name(size_t i) const override;
|
||||
size_t num_targets() const override;
|
||||
@ -151,8 +151,8 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
|
||||
const string& partial_run_handle() const override;
|
||||
size_t num_feeds() const override;
|
||||
const string& feed_name(size_t i) const override;
|
||||
Status FeedValue(size_t i, Tensor* tensor) const override;
|
||||
Status FeedValue(size_t i, TensorProto* tensor) const override;
|
||||
Status FeedValue(size_t i, Tensor* out_tensor) const override;
|
||||
Status FeedValue(size_t i, TensorProto* out_tensor) const override;
|
||||
size_t num_fetches() const override;
|
||||
const string& fetch_name(size_t i) const override;
|
||||
size_t num_targets() const override;
|
||||
@ -188,8 +188,8 @@ class ProtoRunStepRequest : public RunStepRequestWrapper {
|
||||
const string& partial_run_handle() const override;
|
||||
size_t num_feeds() const override;
|
||||
const string& feed_name(size_t i) const override;
|
||||
Status FeedValue(size_t i, Tensor* tensor) const override;
|
||||
Status FeedValue(size_t i, TensorProto* tensor) const override;
|
||||
Status FeedValue(size_t i, Tensor* out_tensor) const override;
|
||||
Status FeedValue(size_t i, TensorProto* out_tensor) const override;
|
||||
size_t num_fetches() const override;
|
||||
const string& fetch_name(size_t i) const override;
|
||||
size_t num_targets() const override;
|
||||
@ -373,6 +373,242 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
|
||||
const RunGraphRequest* const request_; // Not owned.
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Wrapper classes for the `WorkerService.RunGraph` response message.
|
||||
//
|
||||
// The `RunGraphResponse` message can contain potentially large tensor
|
||||
// data as part of its `recv` submessages. Here we provide specialized
|
||||
// wrappers that avoid copying the tensor data wherever possible.
|
||||
//
|
||||
// See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the
|
||||
// protocol buffer definition.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Abstract interface for a mutable RunGraphResponse message.
|
||||
//
|
||||
// Note that there is no corresponding (immutable)
|
||||
// RunGraphResponseWrapper class, because the RunGraphResponse object
|
||||
// is always used as a mutable pointer.
|
||||
class MutableRunGraphResponseWrapper {
|
||||
public:
|
||||
virtual ~MutableRunGraphResponseWrapper() {}
|
||||
|
||||
// A list of tensors corresponding to those requested by
|
||||
// `RunGraphRequest.recv_key`.
|
||||
virtual size_t num_recvs() const = 0;
|
||||
virtual const string& recv_key(size_t i) const = 0;
|
||||
// NOTE: The following methods may perform a destructive read, for
|
||||
// efficiency.
|
||||
virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0;
|
||||
virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0;
|
||||
virtual void AddRecv(const string& key, const Tensor& value) = 0;
|
||||
|
||||
// Submessages that store performance statistics about the subgraph
|
||||
// execution, if necessary.
|
||||
virtual StepStats* mutable_step_stats() = 0;
|
||||
virtual CostGraphDef* mutable_cost_graph() = 0;
|
||||
|
||||
protected:
|
||||
// Returns a mutable protobuf message that represents the contents of
|
||||
// this wrapper, for passing to an RPC subsystem that will populate
|
||||
// the message.
|
||||
//
|
||||
// NOTE: Only `WorkerInterface` subclasses may call this method. The
|
||||
// `InMemoryRunGraphResponse` subclass does not implement this
|
||||
// method, and attempts to call it will fail with a fatal
|
||||
// error. However, as long as callers always call
|
||||
// `WorkerInterface::RunGraphAsync()` with a wrapper object returned
|
||||
// from `WorkerInterface::CreateRunGraphResponse()` called on the
|
||||
// *same* WorkerInterface object, this error will never trigger.
|
||||
virtual RunGraphResponse* get_proto() = 0;
|
||||
friend class WorkerInterface;
|
||||
};
|
||||
|
||||
class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
|
||||
public:
|
||||
// MutableRunGraphResponseWrapper methods.
|
||||
size_t num_recvs() const override;
|
||||
const string& recv_key(size_t i) const override;
|
||||
Status RecvValue(size_t i, TensorProto* out_tensor) override;
|
||||
Status RecvValue(size_t i, Tensor* out_tensor) override;
|
||||
void AddRecv(const string& key, const Tensor& value) override;
|
||||
StepStats* mutable_step_stats() override;
|
||||
CostGraphDef* mutable_cost_graph() override;
|
||||
|
||||
protected:
|
||||
// NOTE: This method is not implemented. See
|
||||
// MutableRunGraphResponseWrapper for an explanation.
|
||||
RunGraphResponse* get_proto() override;
|
||||
|
||||
private:
|
||||
gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
|
||||
StepStats step_stats_;
|
||||
CostGraphDef cost_graph_;
|
||||
};
|
||||
|
||||
// Proto-based message wrapper for use on the client side of the RunGraph RPC.
|
||||
class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
|
||||
public:
|
||||
// MutableRunGraphResponseWrapper methods.
|
||||
size_t num_recvs() const override;
|
||||
const string& recv_key(size_t i) const override;
|
||||
Status RecvValue(size_t i, TensorProto* out_tensor) override;
|
||||
Status RecvValue(size_t i, Tensor* out_tensor) override;
|
||||
void AddRecv(const string& key, const Tensor& value) override;
|
||||
StepStats* mutable_step_stats() override;
|
||||
CostGraphDef* mutable_cost_graph() override;
|
||||
|
||||
protected:
|
||||
RunGraphResponse* get_proto() override;
|
||||
|
||||
private:
|
||||
RunGraphResponse response_;
|
||||
};
|
||||
|
||||
// Proto-based message wrapper for use on the server side of the RunGraph RPC.
|
||||
class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
|
||||
public:
|
||||
NonOwnedProtoRunGraphResponse(RunGraphResponse* response);
|
||||
|
||||
// MutableRunGraphResponseWrapper methods.
|
||||
size_t num_recvs() const override;
|
||||
const string& recv_key(size_t i) const override;
|
||||
Status RecvValue(size_t i, TensorProto* out_tensor) override;
|
||||
Status RecvValue(size_t i, Tensor* out_tensor) override;
|
||||
void AddRecv(const string& key, const Tensor& value) override;
|
||||
StepStats* mutable_step_stats() override;
|
||||
CostGraphDef* mutable_cost_graph() override;
|
||||
|
||||
protected:
|
||||
RunGraphResponse* get_proto() override;
|
||||
|
||||
private:
|
||||
RunGraphResponse* const response_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Wrapper classes for the `MasterService.RunStep` response message.
|
||||
//
|
||||
// The `RunStepResponse` message can contain potentially large tensor
|
||||
// data as part of its `tensor` submessages. Here we provide specialized
|
||||
// wrappers that avoid copying the tensor data wherever possible.
|
||||
//
|
||||
// See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the
|
||||
// protocol buffer definition.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Abstract interface for a mutable RunStepResponse message.
|
||||
//
|
||||
// Note that there is no corresponding (immutable)
|
||||
// RunStepResponseWrapper class, because the RunStepResponse object is
|
||||
// always used as a mutable pointer.
|
||||
class MutableRunStepResponseWrapper {
|
||||
public:
|
||||
virtual ~MutableRunStepResponseWrapper();
|
||||
|
||||
// The values of the tensors whose fetching was requested in the
|
||||
// RunStep call.
|
||||
//
|
||||
// NOTE: The order of the returned tensors may or may not match
|
||||
// the fetch order specified in RunStepRequest.
|
||||
virtual size_t num_tensors() const = 0;
|
||||
virtual const string& tensor_name(size_t i) const = 0;
|
||||
virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0;
|
||||
|
||||
// Stores the i^{th} recv value in `run_graph_response` in this
|
||||
// response with the given `name`.
|
||||
virtual Status AddTensorFromRunGraphResponse(
|
||||
const string& name, MutableRunGraphResponseWrapper* run_graph_response,
|
||||
size_t i) = 0;
|
||||
|
||||
// Returned metadata if requested in the options.
|
||||
virtual const RunMetadata& metadata() const = 0;
|
||||
virtual RunMetadata* mutable_metadata() = 0;
|
||||
|
||||
protected:
|
||||
// Returns a mutable protobuf message that represents the contents of
|
||||
// this wrapper, for passing to an RPC subsystem that will populate
|
||||
// the message.
|
||||
//
|
||||
// NOTE: Only `MasterInterface` subclasses may call this method. The
|
||||
// `InMemoryRunStepResponse` subclass does not implement this
|
||||
// method, and attempts to call it will fail with a fatal
|
||||
// error. However, as long as callers always call
|
||||
// `MasterInterface::RunStep()` with a wrapper object returned
|
||||
// from `MasterInterface::CreateRunStepResponse()` called on the
|
||||
// *same* MasterInterface object, this error will never trigger.
|
||||
virtual RunStepResponse* get_proto() = 0;
|
||||
friend class MasterInterface;
|
||||
};
|
||||
|
||||
class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
|
||||
public:
|
||||
// MutableRunStepResponseWrapper methods.
|
||||
size_t num_tensors() const override;
|
||||
const string& tensor_name(size_t i) const override;
|
||||
Status TensorValue(size_t i, Tensor* out_tensor) const override;
|
||||
Status AddTensorFromRunGraphResponse(
|
||||
const string& name, MutableRunGraphResponseWrapper* run_graph_response,
|
||||
size_t i) override;
|
||||
const RunMetadata& metadata() const override;
|
||||
RunMetadata* mutable_metadata() override;
|
||||
|
||||
protected:
|
||||
// NOTE: This method is not implemented. See
|
||||
// MutableRunGraphResponseWrapper for an explanation.
|
||||
RunStepResponse* get_proto() override;
|
||||
|
||||
private:
|
||||
gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
|
||||
RunMetadata metadata_;
|
||||
};
|
||||
|
||||
// Proto-based message wrapper for use on the client side of the RunStep RPC.
|
||||
class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
|
||||
public:
|
||||
// MutableRunStepResponseWrapper methods.
|
||||
size_t num_tensors() const override;
|
||||
const string& tensor_name(size_t i) const override;
|
||||
Status TensorValue(size_t i, Tensor* out_tensor) const override;
|
||||
Status AddTensorFromRunGraphResponse(
|
||||
const string& name, MutableRunGraphResponseWrapper* run_graph_response,
|
||||
size_t i) override;
|
||||
const RunMetadata& metadata() const override;
|
||||
RunMetadata* mutable_metadata() override;
|
||||
|
||||
protected:
|
||||
RunStepResponse* get_proto() override;
|
||||
|
||||
private:
|
||||
RunStepResponse response_;
|
||||
};
|
||||
|
||||
// Proto-based message wrapper for use on the server side of the RunStep RPC.
|
||||
class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
|
||||
public:
|
||||
NonOwnedProtoRunStepResponse(RunStepResponse* response);
|
||||
|
||||
// MutableRunStepResponseWrapper methods.
|
||||
size_t num_tensors() const override;
|
||||
const string& tensor_name(size_t i) const override;
|
||||
Status TensorValue(size_t i, Tensor* out_tensor) const override;
|
||||
Status AddTensorFromRunGraphResponse(
|
||||
const string& name, MutableRunGraphResponseWrapper* run_graph_response,
|
||||
size_t i) override;
|
||||
const RunMetadata& metadata() const override;
|
||||
RunMetadata* mutable_metadata() override;
|
||||
|
||||
protected:
|
||||
RunStepResponse* get_proto() override;
|
||||
|
||||
private:
|
||||
RunStepResponse* response_; // Not owned.
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW
|
||||
|
@ -92,6 +92,55 @@ static void CheckRunGraphRequest(const RunGraphRequestWrapper& request) {
|
||||
EXPECT_FALSE(request.is_last_partial_run());
|
||||
}
|
||||
|
||||
static void BuildRunGraphResponse(
|
||||
MutableRunGraphResponseWrapper* run_graph_response) {
|
||||
run_graph_response->AddRecv("recv_2", TensorA());
|
||||
run_graph_response->AddRecv("recv_3", TensorB());
|
||||
run_graph_response->mutable_step_stats()->add_dev_stats()->set_device(
|
||||
"/cpu:0");
|
||||
run_graph_response->mutable_cost_graph()->add_node()->set_name("cost_node");
|
||||
}
|
||||
|
||||
static void CheckRunGraphResponse(MutableRunGraphResponseWrapper* response) {
|
||||
EXPECT_EQ(2, response->num_recvs());
|
||||
EXPECT_EQ("recv_2", response->recv_key(0));
|
||||
EXPECT_EQ("recv_3", response->recv_key(1));
|
||||
Tensor val;
|
||||
response->RecvValue(0, &val);
|
||||
test::ExpectTensorEqual<int32>(TensorA(), val);
|
||||
response->RecvValue(1, &val);
|
||||
test::ExpectTensorEqual<int32>(TensorB(), val);
|
||||
EXPECT_EQ(1, response->mutable_step_stats()->dev_stats_size());
|
||||
EXPECT_EQ("/cpu:0", response->mutable_step_stats()->dev_stats(0).device());
|
||||
EXPECT_EQ(1, response->mutable_cost_graph()->node_size());
|
||||
EXPECT_EQ("cost_node", response->mutable_cost_graph()->node(0).name());
|
||||
}
|
||||
|
||||
static void BuildRunStepResponse(
|
||||
MutableRunGraphResponseWrapper* run_graph_response,
|
||||
MutableRunStepResponseWrapper* run_step_response) {
|
||||
run_step_response->AddTensorFromRunGraphResponse("fetch_x:0",
|
||||
run_graph_response, 0);
|
||||
run_step_response->AddTensorFromRunGraphResponse("fetch_y:0",
|
||||
run_graph_response, 1);
|
||||
*run_step_response->mutable_metadata()->mutable_step_stats() =
|
||||
*run_graph_response->mutable_step_stats();
|
||||
}
|
||||
|
||||
static void CheckRunStepResponse(
|
||||
const MutableRunStepResponseWrapper& response) {
|
||||
EXPECT_EQ(2, response.num_tensors());
|
||||
EXPECT_EQ("fetch_x:0", response.tensor_name(0));
|
||||
EXPECT_EQ("fetch_y:0", response.tensor_name(1));
|
||||
Tensor val;
|
||||
response.TensorValue(0, &val);
|
||||
test::ExpectTensorEqual<int32>(TensorA(), val);
|
||||
response.TensorValue(1, &val);
|
||||
test::ExpectTensorEqual<int32>(TensorB(), val);
|
||||
EXPECT_EQ(1, response.metadata().step_stats().dev_stats_size());
|
||||
EXPECT_EQ("/cpu:0", response.metadata().step_stats().dev_stats(0).device());
|
||||
}
|
||||
|
||||
TEST(MessageWrappers, RunStepRequest_Basic) {
|
||||
InMemoryRunStepRequest in_memory_request;
|
||||
BuildRunStepRequest(&in_memory_request);
|
||||
@ -164,4 +213,108 @@ TEST(MessageWrappers, RunGraphRequest_Basic) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MessageWrappers, RunGraphResponse_Basic) {
|
||||
InMemoryRunGraphResponse in_memory_response;
|
||||
BuildRunGraphResponse(&in_memory_response);
|
||||
CheckRunGraphResponse(&in_memory_response);
|
||||
|
||||
OwnedProtoRunGraphResponse owned_proto_response;
|
||||
BuildRunGraphResponse(&owned_proto_response);
|
||||
CheckRunGraphResponse(&owned_proto_response);
|
||||
|
||||
RunGraphResponse response_proto;
|
||||
NonOwnedProtoRunGraphResponse non_owned_proto_response(&response_proto);
|
||||
BuildRunGraphResponse(&non_owned_proto_response);
|
||||
CheckRunGraphResponse(&non_owned_proto_response);
|
||||
}
|
||||
|
||||
TEST(MessageWrappers, RunStepResponse_Basic) {
|
||||
{
|
||||
// Worker -(in memory)-> Master -(in memory)-> Client.
|
||||
InMemoryRunGraphResponse run_graph_response;
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
InMemoryRunStepResponse response;
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
|
||||
{
|
||||
// Worker -(in memory)-> Master -(owned proto)-> Client.
|
||||
InMemoryRunGraphResponse run_graph_response;
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
OwnedProtoRunStepResponse response;
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
|
||||
{
|
||||
// Worker -(in memory)-> Master -(non-owned proto)-> Client.
|
||||
InMemoryRunGraphResponse run_graph_response;
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
RunStepResponse response_proto;
|
||||
NonOwnedProtoRunStepResponse response(&response_proto);
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
|
||||
{
|
||||
// Worker -(owned proto)-> Master -(in memory)-> Client.
|
||||
OwnedProtoRunGraphResponse run_graph_response;
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
InMemoryRunStepResponse response;
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
|
||||
{
|
||||
// Worker -(owned proto)-> Master -(owned proto)-> Client.
|
||||
OwnedProtoRunGraphResponse run_graph_response;
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
OwnedProtoRunStepResponse response;
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
|
||||
{
|
||||
// Worker -(owned proto)-> Master -(non-owned proto)-> Client.
|
||||
OwnedProtoRunGraphResponse run_graph_response;
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
RunStepResponse response_proto;
|
||||
NonOwnedProtoRunStepResponse response(&response_proto);
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
|
||||
{
|
||||
// Worker -(non-owned proto)-> Master -(in memory)-> Client.
|
||||
RunGraphResponse run_graph_response_proto;
|
||||
NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto);
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
InMemoryRunStepResponse response;
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
|
||||
{
|
||||
// Worker -(non-owned proto)-> Master -(owned proto)-> Client.
|
||||
RunGraphResponse run_graph_response_proto;
|
||||
NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto);
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
OwnedProtoRunStepResponse response;
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
|
||||
{
|
||||
// Worker -(non-owned proto)-> Master -(non-owned proto)-> Client.
|
||||
RunGraphResponse run_graph_response_proto;
|
||||
NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto);
|
||||
BuildRunGraphResponse(&run_graph_response);
|
||||
RunStepResponse response_proto;
|
||||
NonOwnedProtoRunStepResponse response(&response_proto);
|
||||
BuildRunStepResponse(&run_graph_response, &response);
|
||||
CheckRunStepResponse(response);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -174,15 +174,17 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
CallOptions* call_opts = new CallOptions;
|
||||
RunStepRequestWrapper* wrapped_request =
|
||||
new ProtoRunStepRequest(&call->request);
|
||||
MutableRunStepResponseWrapper* wrapped_response =
|
||||
new NonOwnedProtoRunStepResponse(&call->response);
|
||||
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
||||
master_impl_->RunStep(
|
||||
call_opts, wrapped_request, &call->response,
|
||||
[call, call_opts, wrapped_request](const Status& status) {
|
||||
call->ClearCancelCallback();
|
||||
delete call_opts;
|
||||
delete wrapped_request;
|
||||
call->SendResponse(ToGrpcStatus(status));
|
||||
});
|
||||
master_impl_->RunStep(call_opts, wrapped_request, wrapped_response,
|
||||
[call, call_opts, wrapped_request,
|
||||
wrapped_response](const Status& status) {
|
||||
call->ClearCancelCallback();
|
||||
delete call_opts;
|
||||
delete wrapped_request;
|
||||
call->SendResponse(ToGrpcStatus(status));
|
||||
});
|
||||
ENQUEUE_REQUEST(RunStep, true);
|
||||
}
|
||||
|
||||
|
@ -62,11 +62,12 @@ class GrpcRemoteMaster : public MasterInterface {
|
||||
}
|
||||
|
||||
Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
|
||||
RunStepResponse* response) override {
|
||||
MutableRunStepResponseWrapper* response) override {
|
||||
::grpc::ClientContext ctx;
|
||||
ctx.set_fail_fast(false);
|
||||
SetDeadline(&ctx, call_options->GetTimeout());
|
||||
return FromGrpcStatus(stub_->RunStep(&ctx, request->ToProto(), response));
|
||||
return FromGrpcStatus(stub_->RunStep(&ctx, request->ToProto(),
|
||||
get_proto_from_wrapper(response)));
|
||||
}
|
||||
|
||||
Status CloseSession(CallOptions* call_options,
|
||||
|
@ -75,9 +75,10 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
IssueRequest(request, response, rungraph_, std::move(done), call_opts);
|
||||
}
|
||||
void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
|
||||
RunGraphResponse* response, StatusCallback done) override {
|
||||
IssueRequest(&request->ToProto(), response, rungraph_, std::move(done),
|
||||
call_opts);
|
||||
MutableRunGraphResponseWrapper* response,
|
||||
StatusCallback done) override {
|
||||
IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
|
||||
rungraph_, std::move(done), call_opts);
|
||||
}
|
||||
|
||||
void CleanupGraphAsync(const CleanupGraphRequest* request,
|
||||
|
@ -171,7 +171,8 @@ Status GrpcSession::RunHelper(
|
||||
// Convert to proto
|
||||
std::unique_ptr<MutableRunStepRequestWrapper> req(
|
||||
master_->CreateRunStepRequest());
|
||||
RunStepResponse resp;
|
||||
std::unique_ptr<MutableRunStepResponseWrapper> resp(
|
||||
master_->CreateRunStepResponse());
|
||||
|
||||
*req->mutable_options() = run_options;
|
||||
|
||||
@ -196,31 +197,27 @@ Status GrpcSession::RunHelper(
|
||||
|
||||
CallOptions call_options;
|
||||
call_options.SetTimeout(run_options.timeout_in_ms());
|
||||
TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), &resp));
|
||||
TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));
|
||||
|
||||
if (!output_tensor_names.empty()) {
|
||||
outputs->resize(output_tensor_names.size());
|
||||
}
|
||||
|
||||
// Convert response back to Tensors in the correct order.
|
||||
for (const NamedTensorProto& tensor : resp.tensor()) {
|
||||
auto fetch_it = output_name_to_offset.find(tensor.name());
|
||||
for (size_t i = 0; i < resp->num_tensors(); ++i) {
|
||||
auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
|
||||
if (fetch_it == output_name_to_offset.end()) {
|
||||
return errors::Internal("Received response for unrequested fetch: ",
|
||||
tensor.name());
|
||||
resp->tensor_name(i));
|
||||
}
|
||||
|
||||
Tensor output;
|
||||
if (!output.FromProto(tensor.tensor())) {
|
||||
return errors::InvalidArgument("Could not parse returned proto for ",
|
||||
tensor.name());
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
|
||||
(*outputs)[fetch_it->second] = output;
|
||||
}
|
||||
|
||||
if (run_metadata) {
|
||||
run_metadata->Swap(resp.mutable_metadata());
|
||||
run_metadata->Swap(resp->mutable_metadata());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -248,7 +245,7 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
|
||||
Status GrpcSession::RunProto(CallOptions* call_options,
|
||||
MutableRunStepRequestWrapper* req,
|
||||
RunStepResponse* resp) {
|
||||
MutableRunStepResponseWrapper* resp) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (handle_.empty()) {
|
||||
|
@ -119,7 +119,7 @@ class GrpcSession : public Session {
|
||||
const string& prun_handle);
|
||||
|
||||
Status RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req,
|
||||
RunStepResponse* resp);
|
||||
MutableRunStepResponseWrapper* resp);
|
||||
|
||||
// Implementations for all the public interfaces.
|
||||
Status CreateImpl(CallOptions* call_options, const GraphDef& graph);
|
||||
|
@ -215,15 +215,18 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
CallOptions* call_opts = new CallOptions;
|
||||
ProtoRunGraphRequest* wrapped_request =
|
||||
new ProtoRunGraphRequest(&call->request);
|
||||
NonOwnedProtoRunGraphResponse* wrapped_response =
|
||||
new NonOwnedProtoRunGraphResponse(&call->response);
|
||||
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
||||
worker_->RunGraphAsync(
|
||||
call_opts, wrapped_request, &call->response,
|
||||
[call, call_opts, wrapped_request](const Status& s) {
|
||||
call->ClearCancelCallback();
|
||||
delete call_opts;
|
||||
delete wrapped_request;
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
|
||||
[call, call_opts, wrapped_request,
|
||||
wrapped_response](const Status& s) {
|
||||
call->ClearCancelCallback();
|
||||
delete call_opts;
|
||||
delete wrapped_request;
|
||||
delete wrapped_response;
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
});
|
||||
ENQUEUE_REQUEST(RunGraph, true);
|
||||
}
|
||||
|
@ -110,7 +110,8 @@ Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
|
||||
}
|
||||
|
||||
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
|
||||
RunGraphResponse* response, StatusCallback done) {
|
||||
MutableRunGraphResponseWrapper* response,
|
||||
StatusCallback done) {
|
||||
if (request->is_partial()) {
|
||||
DoPartialRunGraph(opts, request, response, std::move(done));
|
||||
} else {
|
||||
@ -122,8 +123,13 @@ MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
|
||||
return new InMemoryRunGraphRequest;
|
||||
}
|
||||
|
||||
MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
|
||||
return new InMemoryRunGraphResponse;
|
||||
}
|
||||
|
||||
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
|
||||
RunGraphResponse* response, StatusCallback done) {
|
||||
MutableRunGraphResponseWrapper* response,
|
||||
StatusCallback done) {
|
||||
const int64 step_id = request->step_id();
|
||||
TRACEPRINTF("RunGraph: %lld", step_id);
|
||||
GraphMgr::NamedTensors in;
|
||||
@ -179,11 +185,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
|
||||
for (const auto& p : *out) {
|
||||
const string& key = p.first;
|
||||
const Tensor& val = p.second;
|
||||
auto* recv = response->add_recv();
|
||||
recv->set_name(key);
|
||||
// TODO(zhifengc): Deal with gpu -> cpu copy.
|
||||
TensorProto* proto = recv->mutable_tensor();
|
||||
val.AsProtoTensorContent(proto);
|
||||
response->AddRecv(key, val);
|
||||
}
|
||||
}
|
||||
delete collector;
|
||||
@ -195,7 +197,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
|
||||
// TODO(suharshs): Add stats collection support to partial run.
|
||||
void Worker::DoPartialRunGraph(CallOptions* opts,
|
||||
RunGraphRequestWrapper* request,
|
||||
RunGraphResponse* response,
|
||||
MutableRunGraphResponseWrapper* response,
|
||||
StatusCallback done) {
|
||||
const int64 step_id = request->step_id();
|
||||
const string& graph_handle = request->graph_handle();
|
||||
@ -276,11 +278,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
|
||||
for (const auto& p : *out) {
|
||||
const string& key = p.first;
|
||||
const Tensor& val = p.second;
|
||||
auto* recv = response->add_recv();
|
||||
recv->set_name(key);
|
||||
// TODO(zhifengc): Deal with gpu -> cpu copy.
|
||||
TensorProto* proto = recv->mutable_tensor();
|
||||
val.AsProtoTensorContent(proto);
|
||||
response->AddRecv(key, val);
|
||||
}
|
||||
|
||||
// If this is the last partial run request we must also wait for the entire
|
||||
|
@ -54,10 +54,13 @@ class Worker : public WorkerInterface {
|
||||
StatusCallback done) override;
|
||||
|
||||
void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
|
||||
RunGraphResponse* response, StatusCallback done) override;
|
||||
MutableRunGraphResponseWrapper* response,
|
||||
StatusCallback done) override;
|
||||
|
||||
MutableRunGraphRequestWrapper* CreateRunGraphRequest() override;
|
||||
|
||||
MutableRunGraphResponseWrapper* CreateRunGraphResponse() override;
|
||||
|
||||
void CleanupGraphAsync(const CleanupGraphRequest* request,
|
||||
CleanupGraphResponse* response,
|
||||
StatusCallback done) override;
|
||||
@ -117,10 +120,12 @@ class Worker : public WorkerInterface {
|
||||
GraphMgr::NamedTensors* out);
|
||||
|
||||
void DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
|
||||
RunGraphResponse* response, StatusCallback done);
|
||||
MutableRunGraphResponseWrapper* response,
|
||||
StatusCallback done);
|
||||
|
||||
void DoPartialRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
|
||||
RunGraphResponse* response, StatusCallback done);
|
||||
MutableRunGraphResponseWrapper* response,
|
||||
StatusCallback done);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Worker);
|
||||
};
|
||||
|
@ -49,7 +49,7 @@ class WorkerInterface {
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
|
||||
RunGraphResponse* repsonse,
|
||||
MutableRunGraphResponseWrapper* repsonse,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
|
||||
@ -57,17 +57,34 @@ class WorkerInterface {
|
||||
// TODO(mrry): Convert this to std::bind/std::move if the overhead
|
||||
// of std::function copying becomes too much.
|
||||
RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
|
||||
RunGraphAsync(opts, wrapped_request, response,
|
||||
[wrapped_request, done](const Status& s) {
|
||||
MutableRunGraphResponseWrapper* wrapped_response =
|
||||
new NonOwnedProtoRunGraphResponse(response);
|
||||
RunGraphAsync(opts, wrapped_request, wrapped_response,
|
||||
[wrapped_request, wrapped_response, done](const Status& s) {
|
||||
done(s);
|
||||
delete wrapped_request;
|
||||
delete wrapped_response;
|
||||
});
|
||||
}
|
||||
|
||||
// Returns a request object for use in calls to
|
||||
// `RunGraphAsync()`. Ownership is transferred to the caller.
|
||||
//
|
||||
// The message returned from this method must only be used in a
|
||||
// `RunGraph()` call on the same `WorkerInterface` instance.
|
||||
virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
|
||||
return new MutableProtoRunGraphRequest;
|
||||
}
|
||||
|
||||
// Returns a response object for use in calls to
|
||||
// `RunGraphAsync()`. Ownership is transferred to the caller.
|
||||
//
|
||||
// The message returned from this method must only be used in a
|
||||
// `RunGraph()` call on the same `WorkerInterface` instance.
|
||||
virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
|
||||
return new OwnedProtoRunGraphResponse;
|
||||
}
|
||||
|
||||
virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
|
||||
CleanupGraphResponse* response,
|
||||
StatusCallback done) = 0;
|
||||
@ -126,6 +143,14 @@ class WorkerInterface {
|
||||
virtual ~WorkerInterface() {}
|
||||
friend class WorkerCacheInterface;
|
||||
|
||||
// NOTE: This should only be called by implementations of this
|
||||
// interface whose CreateRunGraphResponse() method returns a
|
||||
// proto-based wrappers for the RunGraphResponse message.
|
||||
RunGraphResponse* get_proto_from_wrapper(
|
||||
MutableRunGraphResponseWrapper* wrapper) {
|
||||
return wrapper->get_proto();
|
||||
}
|
||||
|
||||
private:
|
||||
typedef WorkerInterface ME;
|
||||
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -80,8 +81,7 @@ InferenceContext::InferenceContext(
|
||||
PostInputInit(input_handle_shapes, input_handle_dtypes);
|
||||
}
|
||||
|
||||
InferenceContext::~InferenceContext() {
|
||||
}
|
||||
InferenceContext::~InferenceContext() {}
|
||||
|
||||
Status InferenceContext::set_output(StringPiece output_name,
|
||||
const std::vector<ShapeHandle>& shapes) {
|
||||
@ -231,6 +231,11 @@ string InferenceContext::DebugString(DimensionHandle d) {
|
||||
return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
|
||||
}
|
||||
|
||||
string InferenceContext::DebugString() const {
|
||||
return strings::StrCat("InferenceContext for node: ",
|
||||
ProtoDebugString(node_def_));
|
||||
}
|
||||
|
||||
Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
|
||||
ShapeHandle* out) {
|
||||
const int32 existing = Rank(shape);
|
||||
|
@ -259,6 +259,9 @@ class InferenceContext {
|
||||
string DebugString(ShapeHandle s);
|
||||
string DebugString(DimensionHandle d);
|
||||
|
||||
// Describes the whole context, for debugging purposes.
|
||||
string DebugString() const;
|
||||
|
||||
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
|
||||
// the shape with asserted rank in <*out>. Otherwise return an error.
|
||||
//
|
||||
|
@ -66,6 +66,7 @@ class GraphConstructor {
|
||||
? in.prefix
|
||||
: in.prefix + "/"),
|
||||
input_map(in.input_map),
|
||||
control_dependencies(in.control_dependencies),
|
||||
importing(true) {}
|
||||
|
||||
bool allow_internal_ops;
|
||||
@ -73,6 +74,7 @@ class GraphConstructor {
|
||||
|
||||
string prefix;
|
||||
std::map<TensorId, TensorId> input_map;
|
||||
std::vector<string> control_dependencies;
|
||||
|
||||
// TODO(ashankar): This bool exists to separate out functionality required
|
||||
// to make ImportGraphDef a close equivalent of Python's import_graph_def
|
||||
@ -107,7 +109,7 @@ class GraphConstructor {
|
||||
|
||||
Status TryImport() {
|
||||
TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
|
||||
TF_RETURN_IF_ERROR(ValidateInputMap());
|
||||
TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
|
||||
TF_RETURN_IF_ERROR(BuildNodeIndex());
|
||||
TF_RETURN_IF_ERROR(InitFromEdges());
|
||||
TF_RETURN_IF_ERROR(Convert());
|
||||
@ -118,7 +120,7 @@ class GraphConstructor {
|
||||
}
|
||||
|
||||
Status EnsureNoNameCollisions();
|
||||
Status ValidateInputMap();
|
||||
Status ValidateInputMapAndControlDependencies();
|
||||
Status BuildNodeIndex();
|
||||
Status InitFromEdges();
|
||||
Status Convert();
|
||||
@ -132,11 +134,18 @@ class GraphConstructor {
|
||||
Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
|
||||
Status ValidateShape(Node* node);
|
||||
Status ModifyNodeDefForImport(NodeDef* node_def);
|
||||
// Modifies node_def's inputs according to opts_.input_map. input_remapped is
|
||||
// a pre-initialized vector of length node_def->input_size() indicating
|
||||
// whether each input has been remapped.
|
||||
void RemapNodeDefInputs(NodeDef* node_def, std::vector<bool>* input_remapped);
|
||||
void AddPrefixToNodeDef(const std::vector<bool>& input_remapped,
|
||||
// Modifies node_def's inputs according to opts_.input_map.
|
||||
// input_already_exists is a pre-initialized vector of length
|
||||
// node_def->input_size(). This function will mark inputs that are remapped to
|
||||
// true.
|
||||
void RemapNodeDefInputs(NodeDef* node_def,
|
||||
std::vector<bool>* input_already_exists);
|
||||
// input_already_exists is a pre-initialized vector of length
|
||||
// node_def->input_size(). This function will add and mark control inputs as
|
||||
// true.
|
||||
void AddControlDependencies(NodeDef* node_def,
|
||||
std::vector<bool>* input_already_exists);
|
||||
void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
|
||||
NodeDef* node_def);
|
||||
|
||||
// From constructor
|
||||
@ -208,14 +217,27 @@ bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool NodeNameInValues(const std::vector<string>& control_dependencies,
|
||||
const StringPiece& node_name) {
|
||||
return std::find(control_dependencies.begin(), control_dependencies.end(),
|
||||
node_name) != control_dependencies.end();
|
||||
}
|
||||
|
||||
Status GraphConstructor::EnsureNoNameCollisions() {
|
||||
existing_nodes_.reserve(g_->num_nodes());
|
||||
for (Node* n : g_->nodes()) {
|
||||
bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
|
||||
if (already_exists && NodeNameInValues(opts_.input_map, n->name())) {
|
||||
return errors::InvalidArgument(
|
||||
"cannot resolve input_map because multiple nodes exist with name '",
|
||||
n->name(), "'");
|
||||
if (already_exists) {
|
||||
if (NodeNameInValues(opts_.input_map, n->name())) {
|
||||
return errors::InvalidArgument(
|
||||
"cannot resolve input_map because multiple nodes exist with name '",
|
||||
n->name(), "'");
|
||||
}
|
||||
if (NodeNameInValues(opts_.control_dependencies, n->name())) {
|
||||
return errors::InvalidArgument(
|
||||
"cannot resolve control_dependencies because multiple nodes exist "
|
||||
"with name '", n->name(), "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (opts_.prefix.empty() && opts_.importing) {
|
||||
@ -248,7 +270,7 @@ Status GraphConstructor::EnsureNoNameCollisions() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphConstructor::ValidateInputMap() {
|
||||
Status GraphConstructor::ValidateInputMapAndControlDependencies() {
|
||||
for (const auto& mapping : opts_.input_map) {
|
||||
TensorId src = mapping.first;
|
||||
TensorId dst = mapping.second;
|
||||
@ -264,6 +286,13 @@ Status GraphConstructor::ValidateInputMap() {
|
||||
"control edge and non-control edge");
|
||||
}
|
||||
}
|
||||
for (const string& node : opts_.control_dependencies) {
|
||||
if (existing_nodes_.count(node) == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"node '", node, "' in control_dependencies does not exist in "
|
||||
"graph");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -466,9 +495,9 @@ void RemoveInputs(NodeDef* node_def, const std::vector<int>& inputs_to_remove) {
|
||||
}
|
||||
}
|
||||
|
||||
void GraphConstructor::RemapNodeDefInputs(NodeDef* node_def,
|
||||
std::vector<bool>* input_remapped) {
|
||||
DCHECK_EQ(input_remapped->size(), node_def->input_size());
|
||||
void GraphConstructor::RemapNodeDefInputs(
|
||||
NodeDef* node_def, std::vector<bool>* input_already_exists) {
|
||||
DCHECK_EQ(input_already_exists->size(), node_def->input_size());
|
||||
std::set<TensorId> control_inputs;
|
||||
std::vector<int> inputs_to_remove;
|
||||
|
||||
@ -487,13 +516,52 @@ void GraphConstructor::RemapNodeDefInputs(NodeDef* node_def,
|
||||
control_inputs.insert(new_input);
|
||||
}
|
||||
node_def->set_input(i, new_input.ToString());
|
||||
(*input_remapped)[i] = true;
|
||||
(*input_already_exists)[i] = true;
|
||||
}
|
||||
if (!inputs_to_remove.empty()) RemoveInputs(node_def, inputs_to_remove);
|
||||
}
|
||||
|
||||
void GraphConstructor::AddControlDependencies(
|
||||
NodeDef* node_def, std::vector<bool>* input_already_exists) {
|
||||
// To avoid adding redundant control dependencies to every imported node, skip
|
||||
// nodes that will inherit the dependencies from another imported node.
|
||||
bool inherits_deps = false;
|
||||
for (int i = 0; i < node_def->input_size(); ++i) {
|
||||
// Assume we won't inherit dependencies from remapped inputs that already
|
||||
// exist in the graph. Even if we're wrong, we'll only add redundant
|
||||
// dependencies.
|
||||
if ((*input_already_exists)[i]) continue;
|
||||
|
||||
// If this input is a backedge, assume we won't inherit the dependencies.
|
||||
// TODO(skyewm): we have many redundant ParseTensorName calls. It could be
|
||||
// worth optimizing these.
|
||||
TensorId id(ParseTensorName(node_def->input(i)));
|
||||
auto iter = gdef_nodes_.find(id.first);
|
||||
DCHECK(iter != gdef_nodes_.end()) << id.first;
|
||||
if (iter->second.node == nullptr) {
|
||||
// Input hasn't been created yet, indicating it's a backedge.
|
||||
continue;
|
||||
}
|
||||
inherits_deps = true;
|
||||
}
|
||||
if (inherits_deps) return;
|
||||
|
||||
// node_def either has no inputs or all remapped inputs, add the control
|
||||
// dependencies
|
||||
for (const string& control_dep : opts_.control_dependencies) {
|
||||
string input = TensorId(control_dep, Graph::kControlSlot).ToString();
|
||||
const protobuf::RepeatedPtrField<string>& inputs = node_def->input();
|
||||
if (std::find(inputs.begin(), inputs.end(), input) != inputs.end()) {
|
||||
// Control dependency already exists
|
||||
continue;
|
||||
}
|
||||
node_def->add_input(input);
|
||||
input_already_exists->push_back(true);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphConstructor::AddPrefixToNodeDef(
|
||||
const std::vector<bool>& input_remapped, NodeDef* node_def) {
|
||||
const std::vector<bool>& input_already_exists, NodeDef* node_def) {
|
||||
const string& prefix = opts_.prefix;
|
||||
if (prefix.empty()) return;
|
||||
node_def->set_name(strings::StrCat(prefix, node_def->name()));
|
||||
@ -502,7 +570,7 @@ void GraphConstructor::AddPrefixToNodeDef(
|
||||
StringPiece input(node_def->input(i));
|
||||
// Skip remapped inputs (which already exist in g_ and are not being
|
||||
// imported)
|
||||
if (input_remapped[i]) continue;
|
||||
if (input_already_exists[i]) continue;
|
||||
if (input.Consume("^")) {
|
||||
node_def->set_input(i, strings::StrCat("^", prefix, input));
|
||||
} else {
|
||||
@ -540,7 +608,12 @@ Status GraphConstructor::Convert() {
|
||||
NodeDef imported_node_def;
|
||||
const NodeDef* node_def;
|
||||
|
||||
std::vector<bool> input_remapped(original_node_def.input_size(), false);
|
||||
// input_already_exists[i] is true iff the i-th input of the node we're
|
||||
// importing refers to a preexisting node in g_ (i.e. input[i] existed prior
|
||||
// to importing gdef_). Conversely, input_already_exists[i] is false iff
|
||||
// the input refers to a node in gdef_.
|
||||
std::vector<bool> input_already_exists(original_node_def.input_size(),
|
||||
false);
|
||||
|
||||
if (opts_.importing) {
|
||||
// TODO(ashankar): The line below means an additional copy of the NodeDef,
|
||||
@ -549,7 +622,11 @@ Status GraphConstructor::Convert() {
|
||||
// GraphDef* and avoid the copying.
|
||||
imported_node_def = original_node_def;
|
||||
if (!opts_.input_map.empty()) {
|
||||
RemapNodeDefInputs(&imported_node_def, &input_remapped);
|
||||
RemapNodeDefInputs(&imported_node_def, &input_already_exists);
|
||||
}
|
||||
if (!opts_.control_dependencies.empty()) {
|
||||
// Note that input_already_exists can grow here
|
||||
AddControlDependencies(&imported_node_def, &input_already_exists);
|
||||
}
|
||||
node_def = &imported_node_def;
|
||||
} else {
|
||||
@ -562,7 +639,7 @@ Status GraphConstructor::Convert() {
|
||||
Node* src_node;
|
||||
int src_index;
|
||||
|
||||
if (!input_remapped[i]) {
|
||||
if (!input_already_exists[i]) {
|
||||
// Locate input in newly-imported nodes
|
||||
auto iter = gdef_nodes_.find(id.first);
|
||||
DCHECK(iter != gdef_nodes_.end()) << id.first;
|
||||
@ -570,7 +647,7 @@ Status GraphConstructor::Convert() {
|
||||
src_index = id.second;
|
||||
if (src_node == nullptr) has_data_back_edge = true;
|
||||
} else {
|
||||
// Input was remapped according to input_map
|
||||
// Input refers to preexistng node in graph
|
||||
auto iter = existing_nodes_.find(id.first);
|
||||
DCHECK(iter != existing_nodes_.end()) << id.first;
|
||||
src_node = iter->second;
|
||||
@ -595,7 +672,7 @@ Status GraphConstructor::Convert() {
|
||||
|
||||
Node* node;
|
||||
if (opts_.importing) {
|
||||
AddPrefixToNodeDef(input_remapped, &imported_node_def);
|
||||
AddPrefixToNodeDef(input_already_exists, &imported_node_def);
|
||||
TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(MakeNode(*node_def, &node));
|
||||
|
@ -89,6 +89,14 @@ struct ImportGraphDefOptions {
|
||||
// TODO(skyewm): add functionality to retrieve unused `input_map` keys
|
||||
std::map<TensorId, TensorId> input_map;
|
||||
|
||||
// The names of existing nodes in `g` that the imported graph should have
|
||||
// control dependencies on.
|
||||
//
|
||||
// Note that to avoid creating many redundant control edges, ImportGraphDef()
|
||||
// won't add control edges to nodes that will inherit the dependencies from
|
||||
// other nodes in `gdef`.
|
||||
std::vector<string> control_dependencies;
|
||||
|
||||
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
|
||||
// with ops that are not defined in the binary calling ImportGraphDef.
|
||||
// Similar to the producer_op_list argument to import_graph_def in the
|
||||
|
@ -117,8 +117,9 @@ class GraphConstructorTest : public ::testing::Test {
|
||||
bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) {
|
||||
for (const Edge* e : graph_.edges()) {
|
||||
if (e->src()->name() == src && e->src_output() == src_out &&
|
||||
e->dst()->name() == dst && e->dst_input() == dst_in)
|
||||
e->dst()->name() == dst && e->dst_input() == dst_in) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -1198,6 +1199,133 @@ versions {
|
||||
EXPECT_EQ(Status::OK(), s) << s;
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) {
|
||||
ShapeRefiner refiner(graph_.op_registry());
|
||||
|
||||
// Populate graph with nodes we'll use in control deps and input map
|
||||
ExpectOK(
|
||||
"node { name: 'W1' op: 'TestParams' }"
|
||||
"node { name: 'W2' op: 'TestParams' }",
|
||||
ImportGraphDefOptions(), &refiner);
|
||||
|
||||
ImportGraphDefOptions opts;
|
||||
opts.control_dependencies = {"W1", "W2"};
|
||||
opts.prefix = "import";
|
||||
opts.input_map[TensorId("W1", -1)] = TensorId("W1", -1);
|
||||
ExpectOK(
|
||||
R"EOF(
|
||||
node { name: 'W1' op: 'TestParams' }
|
||||
node { name: 'input' op: 'TestInput' }
|
||||
node { name: 'input2' op: 'TestInput' input: [ '^W1' ] }
|
||||
node { name: 't1' op: 'TestMul' input: [ 'input:0', 'input:1' ] }
|
||||
)EOF",
|
||||
opts, &refiner);
|
||||
|
||||
// Sanity checks
|
||||
EXPECT_TRUE(HasNode("import/W1"));
|
||||
EXPECT_TRUE(HasNode("import/input"));
|
||||
EXPECT_TRUE(HasNode("import/input2"));
|
||||
EXPECT_TRUE(HasNode("import/t1"));
|
||||
|
||||
EXPECT_TRUE(HasControlEdge("W1", "import/W1"));
|
||||
EXPECT_TRUE(HasControlEdge("W2", "import/W1"));
|
||||
|
||||
EXPECT_TRUE(HasControlEdge("W1", "import/input"));
|
||||
EXPECT_TRUE(HasControlEdge("W2", "import/input"));
|
||||
|
||||
// Test that t1 doesn't have redundant control edges
|
||||
EXPECT_FALSE(HasControlEdge("W1", "import/t1"));
|
||||
EXPECT_FALSE(HasControlEdge("W2", "import/t1"));
|
||||
EXPECT_TRUE(HasEdge("import/input", 0, "import/t1", 0));
|
||||
EXPECT_TRUE(HasEdge("import/input", 1, "import/t1", 1));
|
||||
|
||||
// Test that input2 has control edges since its only input was remapped
|
||||
EXPECT_TRUE(HasControlEdge("W1", "import/input2"));
|
||||
EXPECT_TRUE(HasControlEdge("W2", "import/input2"));
|
||||
EXPECT_FALSE(HasControlEdge("import/W1", "import/input2"));
|
||||
|
||||
// Test that node defs are consistent with graph
|
||||
Node* w1 = FindNode("import/W1");
|
||||
ASSERT_EQ(w1->def().input_size(), 2);
|
||||
EXPECT_EQ(w1->def().input(0), "^W1");
|
||||
EXPECT_EQ(w1->def().input(1), "^W2");
|
||||
|
||||
Node* input = FindNode("import/input");
|
||||
ASSERT_EQ(input->def().input_size(), 2);
|
||||
EXPECT_EQ(input->def().input(0), "^W1");
|
||||
EXPECT_EQ(input->def().input(1), "^W2");
|
||||
|
||||
Node* input2 = FindNode("import/input2");
|
||||
ASSERT_EQ(input2->def().input_size(), 2);
|
||||
EXPECT_EQ(input2->def().input(0), "^W1");
|
||||
EXPECT_EQ(input2->def().input(1), "^W2");
|
||||
|
||||
Node* t1 = FindNode("import/t1");
|
||||
ASSERT_EQ(t1->def().input_size(), 2);
|
||||
EXPECT_EQ(t1->def().input(0), "import/input:0");
|
||||
EXPECT_EQ(t1->def().input(1), "import/input:1");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
|
||||
ShapeRefiner refiner(graph_.op_registry());
|
||||
|
||||
// Populate graph with nodes we'll use in control deps and input map
|
||||
ExpectOK(
|
||||
"node { name: 'W1' op: 'TestParams' }"
|
||||
"node { name: 'input' op: 'TestInput' }",
|
||||
ImportGraphDefOptions(), &refiner);
|
||||
|
||||
ImportGraphDefOptions opts;
|
||||
opts.control_dependencies.push_back("W1");
|
||||
// Use input_map to ensure the cycle doesn't inherit the control deps from
|
||||
// new_input
|
||||
opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0);
|
||||
|
||||
// ImportGraphDef only allows backedges into merge nodes (since backedges are
|
||||
// only expected in while loops)
|
||||
ExpectOK(
|
||||
R"EOF(
|
||||
node { name: 'new_input' op: 'TestInput' }
|
||||
node { name: 'merge' op: 'Merge' input: [ 'new_input:0', 't1:0' ]
|
||||
attr { key: "N" value: { i: 2 } }
|
||||
attr { key: "T" value: { type: DT_FLOAT } } }
|
||||
node { name: 't1' op: 'TestMul' input: [ 'merge:0', 'merge:0' ] }
|
||||
)EOF",
|
||||
opts, &refiner);
|
||||
|
||||
EXPECT_TRUE(HasNode("new_input"));
|
||||
EXPECT_TRUE(HasNode("merge"));
|
||||
EXPECT_TRUE(HasNode("t1"));
|
||||
|
||||
// Sanity check we created cycle
|
||||
EXPECT_TRUE(HasEdge("merge", 0, "t1", 0));
|
||||
EXPECT_TRUE(HasEdge("t1", 0, "merge", 1));
|
||||
|
||||
// Test that control dep was added to exactly one node of cycle
|
||||
EXPECT_TRUE(HasControlEdge("W1", "merge"));
|
||||
EXPECT_FALSE(HasControlEdge("W1", "t1"));
|
||||
|
||||
// Test that node defs are consistent with graph
|
||||
Node* merge = FindNode("merge");
|
||||
ASSERT_EQ(merge->def().input_size(), 3);
|
||||
EXPECT_EQ(merge->def().input(0), "input:0");
|
||||
EXPECT_EQ(merge->def().input(1), "t1:0");
|
||||
EXPECT_EQ(merge->def().input(2), "^W1");
|
||||
|
||||
Node* t1 = FindNode("t1");
|
||||
ASSERT_EQ(t1->def().input_size(), 2);
|
||||
EXPECT_EQ(t1->def().input(0), "merge:0");
|
||||
EXPECT_EQ(t1->def().input(1), "merge:0");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) {
|
||||
// Control dep that isn't in graph def
|
||||
ImportGraphDefOptions opts;
|
||||
opts.control_dependencies.push_back("W1");
|
||||
ExpectError("node { name: 'W1' op: 'TestParams' }", opts,
|
||||
{"node 'W1' in control_dependencies does not exist in graph"});
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_ErrorsDoNoChangeTheGraph) {
|
||||
GraphDef def;
|
||||
NodeDefBuilder("scope/A", "TestParams").Finalize(def.add_node());
|
||||
|
@ -3721,6 +3721,8 @@ filegroup(
|
||||
# See b/29213790
|
||||
"scatter_nd_op*",
|
||||
"sparse_matmul_op.*",
|
||||
# Lib CURL is not supported on Android.
|
||||
"bigquery*",
|
||||
],
|
||||
),
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_kernel_library",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
@ -30,6 +31,24 @@ filegroup(
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "bigquery_reader_ops",
|
||||
srcs = [
|
||||
"bigquery_reader_ops.cc",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":bigquery_table_accessor",
|
||||
":bigquery_table_partition_proto_cc",
|
||||
"//tensorflow/core:cloud_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:reader_base",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bigquery_table_accessor",
|
||||
srcs = [
|
||||
|
193
tensorflow/core/kernels/cloud/bigquery_reader_ops.cc
Normal file
193
tensorflow/core/kernels/cloud/bigquery_reader_ops.cc
Normal file
@ -0,0 +1,193 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/core/example/example.pb.h"
|
||||
#include "tensorflow/core/framework/reader_op_kernel.h"
|
||||
#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h"
|
||||
#include "tensorflow/core/kernels/cloud/bigquery_table_partition.pb.h"
|
||||
#include "tensorflow/core/kernels/reader_base.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer.
|
||||
|
||||
// This is a helper function for reading table attributes from context.
|
||||
Status GetTableAttrs(OpKernelConstruction* context, string* project_id,
|
||||
string* dataset_id, string* table_id,
|
||||
int64* timestamp_millis, std::vector<string>* columns,
|
||||
string* test_end_point) {
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id));
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id));
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id));
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis));
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("columns", columns));
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Note that overriden methods with names ending in "Locked" are called by
|
||||
// ReaderBase while a mutex is held.
|
||||
// See comments for ReaderBase.
|
||||
class BigQueryReader : public ReaderBase {
|
||||
public:
|
||||
explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor,
|
||||
const string& node_name)
|
||||
: ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")),
|
||||
bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {}
|
||||
|
||||
Status OnWorkStartedLocked() override {
|
||||
BigQueryTablePartition partition;
|
||||
if (!partition.ParseFromString(current_work())) {
|
||||
return errors::InvalidArgument(
|
||||
"Could not parse work as as valid partition.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadLocked(string* key, string* value, bool* produced,
|
||||
bool* at_end) override {
|
||||
*at_end = false;
|
||||
*produced = false;
|
||||
if (bigquery_table_accessor_->Done()) {
|
||||
*at_end = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Example example;
|
||||
int64 row_id;
|
||||
TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example));
|
||||
|
||||
*key = std::to_string(row_id);
|
||||
*value = example.SerializeAsString();
|
||||
*produced = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
// Not owned.
|
||||
BigQueryTableAccessor* bigquery_table_accessor_;
|
||||
};
|
||||
|
||||
class BigQueryReaderOp : public ReaderOpKernel {
|
||||
public:
|
||||
explicit BigQueryReaderOp(OpKernelConstruction* context)
|
||||
: ReaderOpKernel(context) {
|
||||
string table_id;
|
||||
string project_id;
|
||||
string dataset_id;
|
||||
int64 timestamp_millis;
|
||||
std::vector<string> columns;
|
||||
string test_end_point;
|
||||
|
||||
OP_REQUIRES_OK(context,
|
||||
GetTableAttrs(context, &project_id, &dataset_id, &table_id,
|
||||
×tamp_millis, &columns, &test_end_point));
|
||||
OP_REQUIRES_OK(context,
|
||||
BigQueryTableAccessor::New(
|
||||
project_id, dataset_id, table_id, timestamp_millis,
|
||||
kDefaultRowBufferSize, test_end_point, columns,
|
||||
BigQueryTablePartition(), &bigquery_table_accessor_));
|
||||
|
||||
SetReaderFactory([this]() {
|
||||
return new BigQueryReader(bigquery_table_accessor_.get(), name());
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU),
|
||||
BigQueryReaderOp);
|
||||
|
||||
class GenerateBigQueryReaderPartitionsOp : public OpKernel {
|
||||
public:
|
||||
explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
string project_id;
|
||||
string dataset_id;
|
||||
string table_id;
|
||||
int64 timestamp_millis;
|
||||
std::vector<string> columns;
|
||||
string test_end_point;
|
||||
|
||||
OP_REQUIRES_OK(context,
|
||||
GetTableAttrs(context, &project_id, &dataset_id, &table_id,
|
||||
×tamp_millis, &columns, &test_end_point));
|
||||
OP_REQUIRES_OK(context,
|
||||
BigQueryTableAccessor::New(
|
||||
project_id, dataset_id, table_id, timestamp_millis,
|
||||
kDefaultRowBufferSize, test_end_point, columns,
|
||||
BigQueryTablePartition(), &bigquery_table_accessor_));
|
||||
OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context));
|
||||
OP_REQUIRES_OK(context, InitializeTotalNumberOfRows());
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>(
|
||||
total_num_rows_, num_partitions_);
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, TensorShape({num_partitions_}),
|
||||
&output_tensor));
|
||||
|
||||
auto output = output_tensor->template flat<string>();
|
||||
for (int64 i = 0; i < num_partitions_; ++i) {
|
||||
BigQueryTablePartition partition;
|
||||
partition.set_start_index(i * partition_size);
|
||||
partition.set_end_index(
|
||||
std::min(total_num_rows_, (i + 1) * partition_size) - 1);
|
||||
output(i) = partition.SerializeAsString();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Status InitializeTotalNumberOfRows() {
|
||||
total_num_rows_ = bigquery_table_accessor_->total_num_rows();
|
||||
if (total_num_rows_ <= 0) {
|
||||
return errors::FailedPrecondition("Invalid total number of rows.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InitializeNumberOfPartitions(OpKernelConstruction* context) {
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_));
|
||||
if (num_partitions_ <= 0) {
|
||||
return errors::FailedPrecondition("Invalid number of partitions.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 num_partitions_;
|
||||
int64 total_num_rows_;
|
||||
std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU),
|
||||
GenerateBigQueryReaderPartitionsOp);
|
||||
|
||||
} // namespace tensorflow
|
@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h"
|
||||
|
||||
#include "tensorflow/core/example/feature.pb.h"
|
||||
@ -23,6 +22,15 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr size_t kBufferSize = 1024 * 1024; // In bytes.
|
||||
const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";
|
||||
|
||||
bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
|
||||
if (partition.end_index() != -1 &&
|
||||
partition.end_index() < partition.start_index()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status ParseJson(StringPiece json, Json::Value* result) {
|
||||
Json::Reader reader;
|
||||
@ -92,17 +100,18 @@ Status ParseColumnType(const string& type,
|
||||
|
||||
Status BigQueryTableAccessor::New(
|
||||
const string& project_id, const string& dataset_id, const string& table_id,
|
||||
int64 timestamp_millis, int64 row_buffer_size,
|
||||
const std::set<string>& columns, const BigQueryTablePartition& partition,
|
||||
int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
|
||||
const std::vector<string>& columns, const BigQueryTablePartition& partition,
|
||||
std::unique_ptr<BigQueryTableAccessor>* accessor) {
|
||||
return New(project_id, dataset_id, table_id, timestamp_millis,
|
||||
row_buffer_size, columns, partition, nullptr, nullptr, accessor);
|
||||
row_buffer_size, end_point, columns, partition, nullptr, nullptr,
|
||||
accessor);
|
||||
}
|
||||
|
||||
Status BigQueryTableAccessor::New(
|
||||
const string& project_id, const string& dataset_id, const string& table_id,
|
||||
int64 timestamp_millis, int64 row_buffer_size,
|
||||
const std::set<string>& columns, const BigQueryTablePartition& partition,
|
||||
int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
|
||||
const std::vector<string>& columns, const BigQueryTablePartition& partition,
|
||||
std::unique_ptr<AuthProvider> auth_provider,
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory,
|
||||
std::unique_ptr<BigQueryTableAccessor>* accessor) {
|
||||
@ -110,14 +119,16 @@ Status BigQueryTableAccessor::New(
|
||||
return errors::InvalidArgument(
|
||||
"Cannot use zero or negative timestamp to query a table.");
|
||||
}
|
||||
const string& big_query_end_point =
|
||||
end_point.empty() ? kBigQueryEndPoint : end_point;
|
||||
if (auth_provider == nullptr && http_request_factory == nullptr) {
|
||||
accessor->reset(new BigQueryTableAccessor(project_id, dataset_id, table_id,
|
||||
timestamp_millis, row_buffer_size,
|
||||
columns, partition));
|
||||
accessor->reset(new BigQueryTableAccessor(
|
||||
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
|
||||
big_query_end_point, columns, partition));
|
||||
} else {
|
||||
accessor->reset(new BigQueryTableAccessor(
|
||||
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
|
||||
columns, partition, std::move(auth_provider),
|
||||
big_query_end_point, columns, partition, std::move(auth_provider),
|
||||
std::move(http_request_factory)));
|
||||
}
|
||||
return (*accessor)->ReadSchema();
|
||||
@ -125,11 +136,11 @@ Status BigQueryTableAccessor::New(
|
||||
|
||||
BigQueryTableAccessor::BigQueryTableAccessor(
|
||||
const string& project_id, const string& dataset_id, const string& table_id,
|
||||
int64 timestamp_millis, int64 row_buffer_size,
|
||||
const std::set<string>& columns, const BigQueryTablePartition& partition)
|
||||
int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
|
||||
const std::vector<string>& columns, const BigQueryTablePartition& partition)
|
||||
: BigQueryTableAccessor(
|
||||
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
|
||||
columns, partition,
|
||||
end_point, columns, partition,
|
||||
std::unique_ptr<AuthProvider>(new GoogleAuthProvider()),
|
||||
std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory())) {
|
||||
row_buffer_.resize(row_buffer_size);
|
||||
@ -137,15 +148,16 @@ BigQueryTableAccessor::BigQueryTableAccessor(
|
||||
|
||||
BigQueryTableAccessor::BigQueryTableAccessor(
|
||||
const string& project_id, const string& dataset_id, const string& table_id,
|
||||
int64 timestamp_millis, int64 row_buffer_size,
|
||||
const std::set<string>& columns, const BigQueryTablePartition& partition,
|
||||
int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
|
||||
const std::vector<string>& columns, const BigQueryTablePartition& partition,
|
||||
std::unique_ptr<AuthProvider> auth_provider,
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory)
|
||||
: project_id_(project_id),
|
||||
dataset_id_(dataset_id),
|
||||
table_id_(table_id),
|
||||
timestamp_millis_(timestamp_millis),
|
||||
columns_(columns),
|
||||
columns_(columns.begin(), columns.end()),
|
||||
bigquery_end_point_(end_point),
|
||||
partition_(partition),
|
||||
auth_provider_(std::move(auth_provider)),
|
||||
http_request_factory_(std::move(http_request_factory)) {
|
||||
@ -153,10 +165,14 @@ BigQueryTableAccessor::BigQueryTableAccessor(
|
||||
Reset();
|
||||
}
|
||||
|
||||
void BigQueryTableAccessor::SetPartition(
|
||||
Status BigQueryTableAccessor::SetPartition(
|
||||
const BigQueryTablePartition& partition) {
|
||||
if (partition.start_index() < 0) {
|
||||
return errors::InvalidArgument("Start index cannot be negative.");
|
||||
}
|
||||
partition_ = partition;
|
||||
Reset();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void BigQueryTableAccessor::Reset() {
|
||||
@ -172,7 +188,8 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
|
||||
|
||||
// If the next row is already fetched and cached, return the row from the
|
||||
// buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
|
||||
if (next_row_in_buffer_ != -1 && next_row_in_buffer_ < row_buffer_.size()) {
|
||||
if (next_row_in_buffer_ != -1 &&
|
||||
next_row_in_buffer_ < ComputeMaxResultsArg()) {
|
||||
*row_id = first_buffered_row_index_ + next_row_in_buffer_;
|
||||
*example = row_buffer_[next_row_in_buffer_];
|
||||
next_row_in_buffer_++;
|
||||
@ -190,12 +207,12 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
|
||||
// we use the page token (which returns rows faster).
|
||||
if (!next_page_token_.empty()) {
|
||||
TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
|
||||
BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(),
|
||||
BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
|
||||
"&pageToken=", request->EscapeString(next_page_token_))));
|
||||
first_buffered_row_index_ += row_buffer_.size();
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
|
||||
BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(),
|
||||
BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
|
||||
"&startIndex=", first_buffered_row_index_)));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
|
||||
@ -222,6 +239,18 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
|
||||
if (partition_.end_index() == -1) {
|
||||
return row_buffer_.size();
|
||||
}
|
||||
if (IsPartitionEmpty(partition_)) {
|
||||
return 0;
|
||||
}
|
||||
return std::min(static_cast<int64>(row_buffer_.size()),
|
||||
static_cast<int64>(partition_.end_index() -
|
||||
partition_.start_index() + 1));
|
||||
}
|
||||
|
||||
Status BigQueryTableAccessor::ParseColumnValues(
|
||||
const Json::Value& value, const SchemaNode& root_schema_node,
|
||||
Example* example) {
|
||||
@ -364,21 +393,17 @@ Status BigQueryTableAccessor::AppendValueToExample(
|
||||
|
||||
string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() {
|
||||
HttpRequest request;
|
||||
return strings::StrCat("https://www.googleapis.com/bigquery/v2/projects/",
|
||||
return strings::StrCat(bigquery_end_point_, "/projects/",
|
||||
request.EscapeString(project_id_), "/datasets/",
|
||||
request.EscapeString(dataset_id_), "/tables/",
|
||||
request.EscapeString(table_id_), "/");
|
||||
}
|
||||
|
||||
string BigQueryTableAccessor::FullTableName() {
|
||||
return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
|
||||
timestamp_millis_);
|
||||
}
|
||||
|
||||
bool BigQueryTableAccessor::Done() {
|
||||
return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
|
||||
IsPartitionEmpty(partition_) ||
|
||||
(partition_.end_index() != -1 &&
|
||||
partition_.end_index() <=
|
||||
partition_.end_index() <
|
||||
first_buffered_row_index_ + next_row_in_buffer_);
|
||||
}
|
||||
|
||||
|
@ -55,16 +55,23 @@ class BigQueryTableAccessor {
|
||||
};
|
||||
|
||||
/// \brief Creates a new BigQueryTableAccessor object.
|
||||
//
|
||||
// We do not allow relative (negative or zero) snapshot times here since we
|
||||
// want to have a consistent snapshot of the table for the lifetime of this
|
||||
// object.
|
||||
// Use end_point if you want to connect to a different end point than the
|
||||
// official BigQuery end point. Otherwise send an empty string.
|
||||
static Status New(const string& project_id, const string& dataset_id,
|
||||
const string& table_id, int64 timestamp_millis,
|
||||
int64 row_buffer_size, const std::set<string>& columns,
|
||||
int64 row_buffer_size, const string& end_point,
|
||||
const std::vector<string>& columns,
|
||||
const BigQueryTablePartition& partition,
|
||||
std::unique_ptr<BigQueryTableAccessor>* accessor);
|
||||
|
||||
/// \brief Starts reading a new partition.
|
||||
void SetPartition(const BigQueryTablePartition& partition);
|
||||
Status SetPartition(const BigQueryTablePartition& partition);
|
||||
|
||||
/// \brief Returns false if there are more rows available in the current
|
||||
/// \brief Returns true if there are more rows available in the current
|
||||
/// partition.
|
||||
bool Done();
|
||||
|
||||
@ -74,9 +81,11 @@ class BigQueryTableAccessor {
|
||||
/// in the BigQuery service.
|
||||
Status ReadRow(int64* row_id, Example* example);
|
||||
|
||||
/// \brief Returns total number of rows.
|
||||
/// \brief Returns total number of rows in the table.
|
||||
int64 total_num_rows() { return total_num_rows_; }
|
||||
|
||||
virtual ~BigQueryTableAccessor() {}
|
||||
|
||||
private:
|
||||
friend class BigQueryTableAccessorTest;
|
||||
|
||||
@ -95,7 +104,8 @@ class BigQueryTableAccessor {
|
||||
/// these two variables.
|
||||
static Status New(const string& project_id, const string& dataset_id,
|
||||
const string& table_id, int64 timestamp_millis,
|
||||
int64 row_buffer_size, const std::set<string>& columns,
|
||||
int64 row_buffer_size, const string& end_point,
|
||||
const std::vector<string>& columns,
|
||||
const BigQueryTablePartition& partition,
|
||||
std::unique_ptr<AuthProvider> auth_provider,
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory,
|
||||
@ -104,14 +114,16 @@ class BigQueryTableAccessor {
|
||||
/// \brief Constructs an object for a given table and partition.
|
||||
BigQueryTableAccessor(const string& project_id, const string& dataset_id,
|
||||
const string& table_id, int64 timestamp_millis,
|
||||
int64 row_buffer_size, const std::set<string>& columns,
|
||||
int64 row_buffer_size, const string& end_point,
|
||||
const std::vector<string>& columns,
|
||||
const BigQueryTablePartition& partition);
|
||||
|
||||
/// Used for unit testing.
|
||||
BigQueryTableAccessor(
|
||||
const string& project_id, const string& dataset_id,
|
||||
const string& table_id, int64 timestamp_millis, int64 row_buffer_size,
|
||||
const std::set<string>& columns, const BigQueryTablePartition& partition,
|
||||
const string& end_point, const std::vector<string>& columns,
|
||||
const BigQueryTablePartition& partition,
|
||||
std::unique_ptr<AuthProvider> auth_provider,
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory);
|
||||
|
||||
@ -132,7 +144,7 @@ class BigQueryTableAccessor {
|
||||
Status AppendValueToExample(const string& column_name,
|
||||
const Json::Value& column_value,
|
||||
const BigQueryTableAccessor::ColumnType type,
|
||||
Example* ex);
|
||||
Example* example);
|
||||
|
||||
/// \brief Resets internal counters for reading a partition.
|
||||
void Reset();
|
||||
@ -140,25 +152,28 @@ class BigQueryTableAccessor {
|
||||
/// \brief Helper function that returns BigQuery http endpoint prefix.
|
||||
string BigQueryUriPrefix();
|
||||
|
||||
/// \brief Computes the maxResults arg to send to BigQuery.
|
||||
int64 ComputeMaxResultsArg();
|
||||
|
||||
/// \brief Returns full name of the underlying table name.
|
||||
string FullTableName();
|
||||
string FullTableName() {
|
||||
return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
|
||||
timestamp_millis_);
|
||||
}
|
||||
|
||||
const string project_id_;
|
||||
const string dataset_id_;
|
||||
const string table_id_;
|
||||
|
||||
// Snapshot timestamp.
|
||||
//
|
||||
// Indicates a snapshot of the table in milliseconds since the epoch.
|
||||
//
|
||||
// We do not allow relative (negative or zero) times here since we want to
|
||||
// have a consistent snapshot of the table for the lifetime of this object.
|
||||
// For more details, see 'Table Decorators' in BigQuery documentation.
|
||||
const int64 timestamp_millis_;
|
||||
|
||||
// Columns that should be read. Empty means all columns.
|
||||
const std::set<string> columns_;
|
||||
|
||||
// HTTP address of BigQuery end point to use.
|
||||
const string bigquery_end_point_;
|
||||
|
||||
// Describes the portion of the table that we are currently accessing.
|
||||
BigQueryTablePartition partition_;
|
||||
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTestProject[] = "test-project";
|
||||
@ -69,10 +68,10 @@ class BigQueryTableAccessorTest : public ::testing::Test {
|
||||
Status CreateTableAccessor(const string& project_id, const string& dataset_id,
|
||||
const string& table_id, int64 timestamp_millis,
|
||||
int64 row_buffer_size,
|
||||
const std::set<string>& columns,
|
||||
const std::vector<string>& columns,
|
||||
const BigQueryTablePartition& partition) {
|
||||
return BigQueryTableAccessor::New(
|
||||
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
|
||||
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "",
|
||||
columns, partition, std::unique_ptr<AuthProvider>(new FakeAuthProvider),
|
||||
std::unique_ptr<HttpRequest::Factory>(
|
||||
new FakeHttpRequestFactory(&requests_)),
|
||||
@ -197,7 +196,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) {
|
||||
kTestRow));
|
||||
BigQueryTablePartition partition;
|
||||
partition.set_start_index(2);
|
||||
partition.set_end_index(3);
|
||||
partition.set_end_index(2);
|
||||
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
|
||||
{}, partition));
|
||||
int64 row_id;
|
||||
@ -227,7 +226,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) {
|
||||
kTestRow));
|
||||
BigQueryTablePartition partition;
|
||||
partition.set_start_index(2);
|
||||
partition.set_end_index(3);
|
||||
partition.set_end_index(2);
|
||||
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
|
||||
{"bool_field", "rec_field.float_field"},
|
||||
partition));
|
||||
@ -258,7 +257,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) {
|
||||
kTestRowWithNulls));
|
||||
BigQueryTablePartition partition;
|
||||
partition.set_start_index(2);
|
||||
partition.set_end_index(3);
|
||||
partition.set_end_index(2);
|
||||
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
|
||||
{}, partition));
|
||||
int64 row_id;
|
||||
@ -288,7 +287,7 @@ TEST_F(BigQueryTableAccessorTest, BrokenRowTest) {
|
||||
kBrokenTestRow));
|
||||
BigQueryTablePartition partition;
|
||||
partition.set_start_index(2);
|
||||
partition.set_end_index(3);
|
||||
partition.set_end_index(2);
|
||||
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
|
||||
{}, partition));
|
||||
int64 row_id;
|
||||
@ -357,7 +356,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
|
||||
kSampleSchema));
|
||||
requests_.emplace_back(new FakeHttpRequest(
|
||||
"Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
|
||||
"datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n"
|
||||
"datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n"
|
||||
"Auth Token: fake_token\n",
|
||||
kTestTwoRows));
|
||||
requests_.emplace_back(new FakeHttpRequest(
|
||||
@ -374,7 +373,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
|
||||
|
||||
BigQueryTablePartition partition;
|
||||
partition.set_start_index(0);
|
||||
partition.set_end_index(1);
|
||||
partition.set_end_index(0);
|
||||
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2,
|
||||
{}, partition));
|
||||
|
||||
@ -396,7 +395,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
|
||||
1234);
|
||||
|
||||
partition.set_start_index(0);
|
||||
partition.set_end_index(2);
|
||||
partition.set_end_index(1);
|
||||
accessor_->SetPartition(partition);
|
||||
TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
|
||||
EXPECT_EQ(0, row_id);
|
||||
@ -410,4 +409,23 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
|
||||
2222);
|
||||
}
|
||||
|
||||
TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) {
|
||||
requests_.emplace_back(new FakeHttpRequest(
|
||||
"Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
|
||||
"datasets/test-dataset/tables/test-table/\n"
|
||||
"Auth Token: fake_token\n",
|
||||
kSampleSchema));
|
||||
|
||||
BigQueryTablePartition partition;
|
||||
partition.set_start_index(3);
|
||||
partition.set_end_index(2);
|
||||
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
|
||||
{}, partition));
|
||||
EXPECT_TRUE(accessor_->Done());
|
||||
|
||||
int64 row_id;
|
||||
Example example;
|
||||
EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example)));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -24,48 +24,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
class InplaceOpBase : public OpKernel {
|
||||
public:
|
||||
explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto value = ctx->input(0);
|
||||
auto loc = ctx->input(1);
|
||||
auto update = ctx->input(2);
|
||||
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()),
|
||||
errors::InvalidArgument("loc must be a vector. ",
|
||||
loc.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, value.dims() == update.dims(),
|
||||
errors::InvalidArgument("value and update shape doesn't match: ",
|
||||
value.shape().DebugString(), " vs. ",
|
||||
update.shape().DebugString()));
|
||||
for (int i = 1; i < value.dims(); ++i) {
|
||||
OP_REQUIRES(
|
||||
ctx, value.dim_size(i) == update.dim_size(i),
|
||||
errors::InvalidArgument("value and update shape doesn't match ",
|
||||
value.shape().DebugString(), " vs. ",
|
||||
update.shape().DebugString()));
|
||||
}
|
||||
OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0),
|
||||
errors::InvalidArgument("loc and update shape doesn't match: ",
|
||||
loc.shape().DebugString(), " vs. ",
|
||||
update.shape().DebugString()));
|
||||
|
||||
Tensor output = value; // This creates an alias intentionally.
|
||||
OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output));
|
||||
ctx->set_output(0, output);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value,
|
||||
const Tensor& loc, Tensor* output) = 0;
|
||||
};
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename T>
|
||||
@ -111,6 +71,48 @@ Status DoInplace(const CPUDevice& d, InplaceOpType op, const Tensor& value,
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO(apassos): validate the shapes better.
|
||||
class InplaceOpBase : public OpKernel {
|
||||
public:
|
||||
explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto value = ctx->input(0);
|
||||
auto loc = ctx->input(1);
|
||||
auto update = ctx->input(2);
|
||||
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()),
|
||||
errors::InvalidArgument("loc must be a vector. ",
|
||||
loc.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, value.dims() == update.dims(),
|
||||
errors::InvalidArgument("value and update shape doesn't match: ",
|
||||
value.shape().DebugString(), " vs. ",
|
||||
update.shape().DebugString()));
|
||||
for (int i = 1; i < value.dims(); ++i) {
|
||||
OP_REQUIRES(
|
||||
ctx, value.dim_size(i) == update.dim_size(i),
|
||||
errors::InvalidArgument("value and update shape doesn't match ",
|
||||
value.shape().DebugString(), " vs. ",
|
||||
update.shape().DebugString()));
|
||||
}
|
||||
OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0),
|
||||
errors::InvalidArgument("loc and update shape doesn't match: ",
|
||||
loc.shape().DebugString(), " vs. ",
|
||||
update.shape().DebugString()));
|
||||
|
||||
Tensor output = value; // This creates an alias intentionally.
|
||||
OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output));
|
||||
ctx->set_output(0, output);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value,
|
||||
const Tensor& loc, Tensor* output) = 0;
|
||||
};
|
||||
|
||||
template <typename Device, functor::InplaceOpType op>
|
||||
class InplaceOp : public InplaceOpBase {
|
||||
public:
|
||||
@ -159,21 +161,27 @@ class EmptyOp : public OpKernel {
|
||||
bool init_;
|
||||
};
|
||||
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("InplaceUpdate").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
InplaceOp<CPUDevice, functor::I_UPDATE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("InplaceAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
InplaceOp<CPUDevice, functor::I_ADD>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("InplaceSubtract").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
InplaceOp<CPUDevice, functor::I_SUB>);
|
||||
class FailureKernel : public OpKernel {
|
||||
public:
|
||||
explicit FailureKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
errors::Internal("Found instance of parallel_stack which "
|
||||
"could not be properly replaced."));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext*) {}
|
||||
};
|
||||
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T"), \
|
||||
InplaceOp<CPUDevice, functor::I_UPDATE>);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER)
|
||||
#undef REGISTER
|
||||
|
||||
#define REGISTER_EMPTY(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Empty") \
|
||||
REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.HostMemory("shape") \
|
||||
.TypeConstraint<type>("dtype"), \
|
||||
@ -182,12 +190,19 @@ TF_CALL_NUMBER_TYPES(REGISTER)
|
||||
TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY)
|
||||
#undef REGISTER_EMPTY
|
||||
|
||||
#define REGISTER_PARALLEL_CONCAT(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ParallelConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
FailureKernel);
|
||||
TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT);
|
||||
#undef REGISTER_PARALLEL_CONCAT
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
#define REGISTER_EMPTY(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Empty") \
|
||||
REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("shape") \
|
||||
.TypeConstraint<type>("dtype"), \
|
||||
@ -195,23 +210,25 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_EMPTY)
|
||||
#undef REGISTER_EMPTY
|
||||
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
InplaceOp<GPUDevice, functor::I_UPDATE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("InplaceAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
InplaceOp<GPUDevice, functor::I_ADD>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("InplaceSubtract").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
InplaceOp<GPUDevice, functor::I_SUB>);
|
||||
#define REGISTER_PARALLEL_CONCAT(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ParallelConcat").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
FailureKernel);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT);
|
||||
#undef REGISTER_PARALLEL_CONCAT
|
||||
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T"), \
|
||||
InplaceOp<GPUDevice, functor::I_UPDATE>);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER)
|
||||
#undef REGISTER
|
||||
|
||||
// Register versions that operate on int32 data on the CPU even though the op
|
||||
// has been placed on the GPU
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("InplaceUpdate")
|
||||
REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("value")
|
||||
.HostMemory("loc")
|
||||
@ -219,24 +236,7 @@ REGISTER_KERNEL_BUILDER(Name("InplaceUpdate")
|
||||
.HostMemory("output")
|
||||
.TypeConstraint<int32>("T"),
|
||||
InplaceOp<CPUDevice, functor::I_UPDATE>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("InplaceAdd")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("value")
|
||||
.HostMemory("loc")
|
||||
.HostMemory("update")
|
||||
.HostMemory("output")
|
||||
.TypeConstraint<int32>("T"),
|
||||
InplaceOp<CPUDevice, functor::I_ADD>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("InplaceSubtract")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("value")
|
||||
.HostMemory("loc")
|
||||
.HostMemory("update")
|
||||
.HostMemory("output")
|
||||
.TypeConstraint<int32>("T"),
|
||||
InplaceOp<CPUDevice, functor::I_SUB>);
|
||||
#endif
|
||||
|
||||
} // end namespace
|
||||
} // end namespace tensorflow
|
||||
|
@ -97,10 +97,9 @@ Status PadShapeFn(InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// tensor value was provided for paddings_t; doublecheck n_dim value is the
|
||||
// same.
|
||||
const auto num_dims = c->Value(n_dim);
|
||||
DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0));
|
||||
const int64 num_dims = paddings_t->shape().dim_size(0);
|
||||
TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
|
||||
TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
|
||||
|
||||
if (paddings_t->dtype() == DT_INT32) {
|
||||
return PadKnown<int32>(c, input, paddings_t, num_dims);
|
||||
@ -165,6 +164,71 @@ Status SetOutputShapeForReshape(InferenceContext* c) {
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_OP("ParallelConcat")
|
||||
.Input("values: N * T")
|
||||
.Output("output: T")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: type")
|
||||
.Attr("shape: shape")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// Validate that the shape attr is correct.
|
||||
TensorShapeProto passed_shape_proto;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &passed_shape_proto));
|
||||
ShapeHandle passed_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->MakeShapeFromShapeProto(passed_shape_proto, &passed_shape));
|
||||
if (!c->FullyDefined(passed_shape)) {
|
||||
return errors::InvalidArgument("shape attr must be fully defined.");
|
||||
}
|
||||
ShapeHandle cur;
|
||||
TF_RETURN_IF_ERROR(c->ReplaceDim(
|
||||
passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)),
|
||||
&cur));
|
||||
for (int i = 0; i < c->num_inputs(); ++i) {
|
||||
if (!c->FullyDefined(c->input(i))) {
|
||||
return errors::InvalidArgument(
|
||||
"All input shapes must be fully defined.");
|
||||
}
|
||||
DimensionHandle unused;
|
||||
if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
|
||||
return errors::InvalidArgument("Size of first dimension must be 1.");
|
||||
}
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
|
||||
"From merging shape ", i,
|
||||
" with other shapes.");
|
||||
}
|
||||
|
||||
c->set_output(0, passed_shape);
|
||||
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Concatenates a list of `N` tensors along the first dimension.
|
||||
|
||||
The input tensors are all required to have size 1 in the first dimension.
|
||||
|
||||
For example:
|
||||
|
||||
```prettyprint
|
||||
# 'x' is [[1, 4]]
|
||||
# 'y' is [[2, 5]]
|
||||
# 'z' is [[3, 6]]
|
||||
parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
|
||||
```
|
||||
|
||||
The difference between concat and parallel_concat is that concat requires all
|
||||
of the inputs be computed before the operation will begin but doesn't require
|
||||
that the input shapes be known during graph construction. Parallel concat
|
||||
will copy pieces of the input into the output as they become available, in
|
||||
some situations this can provide a performance benefit.
|
||||
|
||||
values: Tensors to be concatenated. All must have size 1 in the first dimension
|
||||
and same shape.
|
||||
output: The concatenated tensor.
|
||||
shape: the final shape of the result; should be equal to the shapes of any input
|
||||
but with the number of input values in the first dimension.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("Pack")
|
||||
.Input("values: N * T")
|
||||
.Output("output: T")
|
||||
@ -440,7 +504,8 @@ REGISTER_OP("SplitV")
|
||||
c->set_output(i, output_shape);
|
||||
}
|
||||
} else {
|
||||
// Determine the output shape if split dimension and split sizes are known
|
||||
// Determine the output shape if split dimension and split sizes are
|
||||
// known.
|
||||
int64 split_dim = c->Value(split_dimension);
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
|
||||
std::vector<int64> data;
|
||||
@ -451,12 +516,12 @@ REGISTER_OP("SplitV")
|
||||
}
|
||||
if (num_outputs != data.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Length of size_splits should be equal to num_outputs");
|
||||
"Length of size_splits should be equal to num_outputs");
|
||||
}
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
output_shape = c->UnknownShapeOfRank(rank);
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape));
|
||||
TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim,
|
||||
c->MakeDim(data[i]), &output_shape));
|
||||
c->set_output(i, output_shape);
|
||||
}
|
||||
}
|
||||
@ -1160,7 +1225,7 @@ Equivalent to np.full
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("Empty")
|
||||
REGISTER_OP("_ParallelConcatStart")
|
||||
.Input("shape: Tshape")
|
||||
.Output("output: dtype")
|
||||
.Attr("dtype: type")
|
||||
@ -1186,7 +1251,7 @@ output: An empty Tensor of the specified type.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("InplaceUpdate")
|
||||
REGISTER_OP("_ParallelConcatUpdate")
|
||||
.Input("value: T")
|
||||
.Input("loc: Tshape")
|
||||
.Input("update: T")
|
||||
@ -1225,86 +1290,6 @@ update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
|
||||
output: `value` that has been updated accordingly.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("InplaceAdd")
|
||||
.Input("value: T")
|
||||
.Input("loc: Tshape")
|
||||
.Input("update: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tshape: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Updates input `value` at `loc` by adding `update` elementwise.
|
||||
|
||||
If `loc` is None, `value` and `update` must be the same size.
|
||||
```
|
||||
value += update
|
||||
```
|
||||
|
||||
If `loc` is a scalar, `value` has rank 1 higher than `update`
|
||||
```
|
||||
value[i, :] += update
|
||||
```
|
||||
|
||||
If `loc` is a vector, `value` has the same rank as `update`
|
||||
```
|
||||
value[loc, :] += update
|
||||
```
|
||||
|
||||
If you use this function you will almost certainly want to add
|
||||
a control dependency as done in the implementation of parallel_stack to
|
||||
avoid race conditions.
|
||||
|
||||
value: A `Tensor` object that will be updated in-place.
|
||||
loc: A scalar or 1-D `Tensor` indicating the indices of the first dimension
|
||||
such that value[loc, :] is updated.
|
||||
update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
|
||||
otherwise of rank equal to `value` that contains the new values
|
||||
that will be added to `value`.
|
||||
output: `value` where `update` has been added as appropriate.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("InplaceSubtract")
|
||||
.Input("value: T")
|
||||
.Input("loc: Tshape")
|
||||
.Input("update: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tshape: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Updates input `value` at `loc` by subtracting `update` elementwise.
|
||||
|
||||
If `loc` is None, `value` and `update` must be the same size.
|
||||
```
|
||||
value -= update
|
||||
```
|
||||
|
||||
If `loc` is a scalar, `value` has rank 1 higher than `update`
|
||||
```
|
||||
value[i, :] -= update
|
||||
```
|
||||
|
||||
If `loc` is a vector, `value` has the same rank as `update`
|
||||
```
|
||||
value[loc, :] -= update
|
||||
```
|
||||
|
||||
If you use this function you will almost certainly want to add
|
||||
a control dependency as done in the implementation of parallel_stack to
|
||||
avoid race conditions.
|
||||
|
||||
value: A `Tensor` object that will be updated in-place.
|
||||
loc: A scalar or 1-D `Tensor` indicating the indices of the first dimension
|
||||
such that value[loc, :] is updated.
|
||||
update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
|
||||
otherwise of rank equal to `value` that contains the new values
|
||||
that will be subtracted from `value`.
|
||||
output: `value` where `update` has been subtracted as appropriate.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("Gather")
|
||||
.Input("params: Tparams")
|
||||
@ -1370,8 +1355,8 @@ REGISTER_OP("GatherNd")
|
||||
if (c->Value(r_dim) > c->Rank(params)) {
|
||||
return errors::InvalidArgument(
|
||||
"indices.shape[-1] must be <= params.rank, but saw indices shape: ",
|
||||
c->DebugString(indices), " and params shape: ",
|
||||
c->DebugString(params));
|
||||
c->DebugString(indices),
|
||||
" and params shape: ", c->DebugString(params));
|
||||
}
|
||||
|
||||
// Remove r_dim from indices to get output.
|
||||
@ -1906,12 +1891,12 @@ REGISTER_OP("ReverseSequence")
|
||||
// Validate batch_dim and seq_dim against input.
|
||||
const int32 input_rank = c->Rank(input);
|
||||
if (batch_dim >= input_rank) {
|
||||
return errors::InvalidArgument("batch_dim must be < input rank: ",
|
||||
batch_dim, " vs. ", input_rank);
|
||||
return errors::InvalidArgument(
|
||||
"batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
|
||||
}
|
||||
if (seq_dim >= input_rank) {
|
||||
return errors::InvalidArgument("seq_dim must be < input rank: ",
|
||||
seq_dim, " vs. ", input_rank);
|
||||
return errors::InvalidArgument(
|
||||
"seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
|
||||
}
|
||||
|
||||
DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
|
||||
@ -3790,8 +3775,9 @@ REGISTER_OP("SpaceToDepth")
|
||||
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size,
|
||||
&output_depth));
|
||||
|
||||
c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
|
||||
output_width, output_depth}));
|
||||
c->set_output(0,
|
||||
c->MakeShape({c->Dim(input, 0), output_height, output_width,
|
||||
output_depth}));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
@ -3895,8 +3881,9 @@ REGISTER_OP("DepthToSpace")
|
||||
TF_RETURN_IF_ERROR(c->Divide(c->Dim(input, 3), block_size * block_size,
|
||||
true /* evenly_divisible */, &output_depth));
|
||||
|
||||
c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
|
||||
output_width, output_depth}));
|
||||
c->set_output(0,
|
||||
c->MakeShape({c->Dim(input, 0), output_height, output_width,
|
||||
output_depth}));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
@ -4772,8 +4759,9 @@ Status ScatterNdShape(InferenceContext* c) {
|
||||
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"The outer ", outer_dims, " dimensions of indices.shape=",
|
||||
c->DebugString(indices_shape), " must match the outer ", outer_dims,
|
||||
"The outer ", outer_dims,
|
||||
" dimensions of indices.shape=", c->DebugString(indices_shape),
|
||||
" must match the outer ", outer_dims,
|
||||
" dimensions of updates.shape=", c->DebugString(updates_shape),
|
||||
": ", s.error_message());
|
||||
}
|
||||
|
@ -331,6 +331,7 @@ TEST(ArrayOpsTest, PadD_ShapeFn) {
|
||||
INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
|
||||
INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
|
||||
INFER_OK(op, "?;[3,2]", "[?,?,?]");
|
||||
INFER_OK(op, "?;?", "[?,?,?]");
|
||||
}
|
||||
}
|
||||
|
||||
|
88
tensorflow/core/ops/cloud_ops.cc
Normal file
88
tensorflow/core/ops/cloud_ops.cc
Normal file
@ -0,0 +1,88 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
/* This file registers all cloud ops. */
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::InferenceContext;
|
||||
|
||||
REGISTER_OP("BigQueryReader")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''")
|
||||
.Attr("project_id: string")
|
||||
.Attr("dataset_id: string")
|
||||
.Attr("table_id: string")
|
||||
.Attr("columns: list(string)")
|
||||
.Attr("timestamp_millis: int")
|
||||
.Attr("test_end_point: string = ''")
|
||||
.Output("reader_handle: Ref(string)")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->Vector(2));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
A Reader that outputs rows from a BigQuery table as tensorflow Examples.
|
||||
|
||||
container: If non-empty, this reader is placed in the given container.
|
||||
Otherwise, a default container is used.
|
||||
shared_name: If non-empty, this reader is named in the given bucket
|
||||
with this shared_name. Otherwise, the node name is used instead.
|
||||
project_id: GCP project ID.
|
||||
dataset_id: BigQuery Dataset ID.
|
||||
table_id: Table to read.
|
||||
columns: List of columns to read. Leave empty to read all columns.
|
||||
timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
|
||||
(negative or zero) snapshot times are not allowed. For more details, see
|
||||
'Table Decorators' in BigQuery docs.
|
||||
test_end_point: Do not use. For testing purposes only.
|
||||
reader_handle: The handle to reference the Reader.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("GenerateBigQueryReaderPartitions")
|
||||
.Attr("project_id: string")
|
||||
.Attr("dataset_id: string")
|
||||
.Attr("table_id: string")
|
||||
.Attr("columns: list(string)")
|
||||
.Attr("timestamp_millis: int")
|
||||
.Attr("num_partitions: int")
|
||||
.Attr("test_end_point: string = ''")
|
||||
.Output("partitions: string")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Generates serialized partition messages suitable for batch reads.
|
||||
|
||||
This op should not be used directly by clients. Instead, the
|
||||
bigquery_reader_ops.py file defines a clean interface to the reader.
|
||||
|
||||
project_id: GCP project ID.
|
||||
dataset_id: BigQuery Dataset ID.
|
||||
table_id: Table to read.
|
||||
columns: List of columns to read. Leave empty to read all columns.
|
||||
timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
|
||||
(negative or zero) snapshot times are not allowed. For more details, see
|
||||
'Table Decorators' in BigQuery docs.
|
||||
num_partitions: Number of partitions to split the table into.
|
||||
test_end_point: Do not use. For testing purposes only.
|
||||
partitions: Serialized table partitions.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
@ -11058,77 +11058,6 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Empty"
|
||||
input_arg {
|
||||
name: "shape"
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "dtype"
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "init"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Empty"
|
||||
input_arg {
|
||||
name: "shape"
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "dtype"
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "init"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "EncodeBase64"
|
||||
input_arg {
|
||||
@ -14419,114 +14348,6 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "InplaceAdd"
|
||||
input_arg {
|
||||
name: "value"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "loc"
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
input_arg {
|
||||
name: "update"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "InplaceSubtract"
|
||||
input_arg {
|
||||
name: "value"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "loc"
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
input_arg {
|
||||
name: "update"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "InplaceUpdate"
|
||||
input_arg {
|
||||
name: "value"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "loc"
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
input_arg {
|
||||
name: "update"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Inv"
|
||||
input_arg {
|
||||
@ -19650,6 +19471,32 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ParallelConcat"
|
||||
input_arg {
|
||||
name: "values"
|
||||
type_attr: "T"
|
||||
number_attr: "N"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "shape"
|
||||
type: "shape"
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ParameterizedTruncatedNormal"
|
||||
input_arg {
|
||||
|
@ -43,6 +43,14 @@ Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
|
||||
width = c->UnknownDim();
|
||||
height = c->UnknownDim();
|
||||
} else {
|
||||
// TODO(petewarden) - Remove once we have constant evaluation in C++ only.
|
||||
if (size_tensor->dtype() != DT_INT32) {
|
||||
return errors::InvalidArgument(
|
||||
"Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
|
||||
"but got ",
|
||||
DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
|
||||
" in ", c->DebugString());
|
||||
}
|
||||
auto vec = size_tensor->vec<int32>();
|
||||
height = c->MakeDim(vec(0));
|
||||
width = c->MakeDim(vec(1));
|
||||
@ -74,8 +82,9 @@ Status DecodeImageShapeFn(InferenceContext* c) {
|
||||
channels_dim = c->MakeDim(channels);
|
||||
}
|
||||
|
||||
c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim, channels_dim}));
|
||||
c->set_output(0,
|
||||
c->MakeShape({InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim, channels_dim}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -555,9 +564,10 @@ REGISTER_OP("DecodeGif")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim, 3}));
|
||||
c->set_output(0,
|
||||
c->MakeShape({InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim, 3}));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
|
@ -6371,48 +6371,6 @@ op {
|
||||
}
|
||||
summary: "Computes gradients for the exponential linear (Elu) operation."
|
||||
}
|
||||
op {
|
||||
name: "Empty"
|
||||
input_arg {
|
||||
name: "shape"
|
||||
description: "1-D `Tensor` indicating the shape of the output."
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "An empty Tensor of the specified type."
|
||||
type_attr: "dtype"
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
type: "type"
|
||||
description: "The element type of the returned tensor."
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "init"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "`bool` indicating whether or not to zero the allocated memory."
|
||||
}
|
||||
summary: "Creates an empty Tensor with shape `shape` and type `dtype`."
|
||||
description: "The memory can optionally be initialized. This is usually useful in\nconjunction with inplace operations."
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "EncodeBase64"
|
||||
input_arg {
|
||||
@ -8894,132 +8852,6 @@ op {
|
||||
summary: "Initializes a table from a text file."
|
||||
description: "It inserts one key-value pair into the table for each line of the file.\nThe key and value is extracted from the whole line content, elements from the\nsplit line based on `delimiter` or the line number (starting from zero).\nWhere to extract the key and value from a line is specified by `key_index` and\n`value_index`.\n\n- A value of -1 means use the line number(starting from zero), expects `int64`.\n- A value of -2 means use the whole line content, expects `string`.\n- A value >= 0 means use the index (starting at zero) of the split line based\n on `delimiter`."
|
||||
}
|
||||
op {
|
||||
name: "InplaceAdd"
|
||||
input_arg {
|
||||
name: "value"
|
||||
description: "A `Tensor` object that will be updated in-place."
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "loc"
|
||||
description: "A scalar or 1-D `Tensor` indicating the indices of the first dimension\nsuch that value[loc, :] is updated."
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
input_arg {
|
||||
name: "update"
|
||||
description: "A `Tensor` of rank one less than `value` if `loc` is a scalar,\notherwise of rank equal to `value` that contains the new values\nthat will be added to `value`."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "`value` where `update` has been added as appropriate."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Updates input `value` at `loc` by adding `update` elementwise."
|
||||
description: "If `loc` is None, `value` and `update` must be the same size.\n```\nvalue += update\n```\n\nIf `loc` is a scalar, `value` has rank 1 higher than `update`\n```\nvalue[i, :] += update\n```\n\nIf `loc` is a vector, `value` has the same rank as `update`\n```\nvalue[loc, :] += update\n```\n\nIf you use this function you will almost certainly want to add\na control dependency as done in the implementation of parallel_stack to\navoid race conditions."
|
||||
}
|
||||
op {
|
||||
name: "InplaceSubtract"
|
||||
input_arg {
|
||||
name: "value"
|
||||
description: "A `Tensor` object that will be updated in-place."
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "loc"
|
||||
description: "A scalar or 1-D `Tensor` indicating the indices of the first dimension\nsuch that value[loc, :] is updated."
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
input_arg {
|
||||
name: "update"
|
||||
description: "A `Tensor` of rank one less than `value` if `loc` is a scalar,\notherwise of rank equal to `value` that contains the new values\nthat will be subtracted from `value`."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "`value` where `update` has been subtracted as appropriate."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Updates input `value` at `loc` by subtracting `update` elementwise."
|
||||
description: "If `loc` is None, `value` and `update` must be the same size.\n```\nvalue -= update\n```\n\nIf `loc` is a scalar, `value` has rank 1 higher than `update`\n```\nvalue[i, :] -= update\n```\n\nIf `loc` is a vector, `value` has the same rank as `update`\n```\nvalue[loc, :] -= update\n```\n\nIf you use this function you will almost certainly want to add\na control dependency as done in the implementation of parallel_stack to\navoid race conditions."
|
||||
}
|
||||
op {
|
||||
name: "InplaceUpdate"
|
||||
input_arg {
|
||||
name: "value"
|
||||
description: "A `Tensor` object that will be updated in-place."
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "loc"
|
||||
description: "A scalar or 1-D `Tensor` indicating the indices of the first dimension\nsuch that value[loc, :] is updated."
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
input_arg {
|
||||
name: "update"
|
||||
description: "A `Tensor` of rank one less than `value` if `loc` is a scalar,\notherwise of rank equal to `value` that contains the new values\nfor `value`."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "`value` that has been updated accordingly."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Updates input `value` at `loc` with `update`."
|
||||
description: "If `loc` is None, `value` and `update` must be the same size.\n```\nvalue = update\n```\n\nIf `loc` is a scalar, `value` has rank 1 higher than `update`\n```\nvalue[i, :] = update\n```\n\nIf `loc` is a vector, `value` has the same rank as `update`\n```\nvalue[loc, :] = update\n```\n\nIf you use this function you will almost certainly want to add\na control dependency as done in the implementation of parallel_stack to\navoid race conditions."
|
||||
}
|
||||
op {
|
||||
name: "Inv"
|
||||
input_arg {
|
||||
@ -11830,6 +11662,37 @@ op {
|
||||
description: "Variable-size shapes are allowed by setting the corresponding shape dimensions\nto 0 in the shape attr. In this case DequeueMany will pad up to the maximum\nsize of any given element in the minibatch. See below for details."
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ParallelConcat"
|
||||
input_arg {
|
||||
name: "values"
|
||||
description: "Tensors to be concatenated. All must have size 1 in the first dimension\nand same shape."
|
||||
type_attr: "T"
|
||||
number_attr: "N"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "The concatenated tensor."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "shape"
|
||||
type: "shape"
|
||||
description: "the final shape of the result; should be equal to the shapes of any input\nbut with the number of input values in the first dimension."
|
||||
}
|
||||
summary: "Concatenates a list of `N` tensors along the first dimension."
|
||||
description: "The input tensors are all required to have size 1 in the first dimension.\n\nFor example:\n\n```prettyprint\n# \'x\' is [[1, 4]]\n# \'y\' is [[2, 5]]\n# \'z\' is [[3, 6]]\nparallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.\n```\n\nThe difference between concat and parallel_concat is that concat requires all\nof the inputs be computed before the operation will begin but doesn\'t require\nthat the input shapes be known during graph construction. Parallel concat\nwill copy pieces of the input into the output as they become available, in\nsome situations this can provide a performance benefit."
|
||||
}
|
||||
op {
|
||||
name: "ParameterizedTruncatedNormal"
|
||||
input_arg {
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
load("@protobuf//:protobuf.bzl", "cc_proto_library")
|
||||
load("@protobuf//:protobuf.bzl", "py_proto_library")
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_mobile")
|
||||
|
||||
# configure may change the following lines
|
||||
WITH_GCP_SUPPORT = False
|
||||
@ -207,6 +208,24 @@ def tf_additional_core_deps():
|
||||
deps.append("//tensorflow/core/platform/hadoop:hadoop_file_system")
|
||||
return deps
|
||||
|
||||
# TODO(jart, jhseu): Delete when GCP is default on.
|
||||
def tf_additional_cloud_op_deps():
|
||||
deps = []
|
||||
# TODO(hormati): Remove the comments below to enable BigQuery op. The op is
|
||||
# not linked for now because it is under perf testing.
|
||||
#if WITH_GCP_SUPPORT:
|
||||
# deps = if_not_mobile(["//tensorflow/core/kernels/cloud:bigquery_reader_ops"])
|
||||
return deps
|
||||
|
||||
# TODO(jart, jhseu): Delete when GCP is default on.
|
||||
def tf_additional_cloud_kernel_deps():
|
||||
deps = []
|
||||
# TODO(hormati): Remove the comments below to enable BigQuery op. The op is
|
||||
# not linked for now because it is under perf testing.
|
||||
#if WITH_GCP_SUPPORT:
|
||||
# deps = if_not_mobile(["//tensorflow/core:cloud_ops_op_lib"])
|
||||
return deps
|
||||
|
||||
def tf_additional_plugin_deps():
|
||||
deps = []
|
||||
if WITH_XLA_SUPPORT:
|
||||
|
@ -2044,7 +2044,7 @@ The attr `block_size` indicates the input block size and how the data is moved.
|
||||
|
||||
* Chunks of data of size `block_size * block_size` from depth are rearranged
|
||||
into non-overlapping blocks of size `block_size x block_size`
|
||||
* The width the output tensor is `input_width * block_size`, whereas the
|
||||
* The width the output tensor is `input_depth * block_size`, whereas the
|
||||
height is `input_height * block_size`.
|
||||
* The depth of the input tensor must be divisible by
|
||||
`block_size * block_size`.
|
||||
|
@ -122,6 +122,7 @@ specified then `scale += IdentityMatrix`. Otherwise specifying a
|
||||
`scale_diag` has shape [N1, N2, ... k, k], which represents a k x k
|
||||
lower triangular matrix.
|
||||
When `None` no `scale_tril` term is added to `scale`.
|
||||
The upper triangular elements above the diagonal are ignored.
|
||||
* <b>`scale_perturb_factor`</b>: Numeric `Tensor` representing factor matrix with
|
||||
last two dimensions of shape `(k, r)`.
|
||||
When `None`, no rank-r update is added to `scale`.
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user