Merge pull request #3082 from rmlarsen/branch_126082003

Branch 126082003
Vijay Vasudevan 2016-06-28 13:01:49 -07:00 committed by GitHub
commit cee7cdd23d
128 changed files with 5433 additions and 1417 deletions


@@ -1,7 +1,6 @@
package(default_visibility = ["//visibility:public"])
archive_dir = "eigen-eigen-802d984ade26"
archive_dir = "eigen-eigen-334b1d428283"
cc_library(
name = "eigen",
hdrs = glob([archive_dir+"/**/*.h", archive_dir+"/unsupported/Eigen/CXX11/*", archive_dir+"/Eigen/*"]),


@@ -8,7 +8,7 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "bayesflow_py",
@@ -16,7 +16,7 @@ py_library(
srcs_version = "PY2AND3",
)
cuda_py_tests(
cuda_py_test(
name = "stochastic_graph_test",
size = "small",
srcs = ["python/kernel_tests/stochastic_graph_test.py"],
@@ -27,6 +27,17 @@ cuda_py_tests(
],
)
cuda_py_test(
name = "reinforce_simple_example",
size = "small",
srcs = ["examples/reinforce_simple/reinforce_simple_example.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(


@@ -0,0 +1,143 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple examples of the REINFORCE algorithm."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
distributions = tf.contrib.distributions
sg = tf.contrib.bayesflow.stochastic_graph
def split_apply_merge(inp, partitions, fns):
"""Split input according to partitions. Pass results through fns and merge.
Args:
inp: the input vector
partitions: a tensor of the same length as the input vector, with values in {0, 1}
fns: the two functions.
Returns:
the routed vector, where routed[i] = fns[partitions[i]](inp[i])
"""
new_inputs = tf.dynamic_partition(inp, partitions, len(fns))
new_outputs = [fns[i](x) for i, x in enumerate(new_inputs)]
new_indices = tf.dynamic_partition(
tf.range(0, inp.get_shape()[0]), partitions, len(fns))
return tf.dynamic_stitch(new_indices, new_outputs)
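To see the routing mechanics in isolation, here is a minimal standalone sketch using the same ops (values are illustrative; assumes the TF 0.9-era graph API used in this file):

import tensorflow as tf

inp = tf.constant([10.0, 20.0, 30.0, 40.0])
partitions = tf.constant([0, 1, 0, 1])
parts = tf.dynamic_partition(inp, partitions, 2)           # [[10, 30], [20, 40]]
outs = [parts[0] + 1.0, parts[1] - 1.0]                    # one fn per partition
idx = tf.dynamic_partition(tf.range(0, 4), partitions, 2)  # original positions
merged = tf.dynamic_stitch(idx, outs)                      # back in input order

with tf.Session() as sess:
    print(sess.run(merged))  # [ 11.  19.  31.  39.]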
def plus_1(inputs):
return inputs + 1.0
def minus_1(inputs):
return inputs - 1.0
def build_split_apply_merge_model():
"""Build the Split-Apply-Merge Model.
Route each value of input [-1, -1, 1, 1] through one of the
functions, plus_1, minus_1. The decision for routing is made by
4 Bernoulli R.V.s whose parameters are determined by a neural network
applied to the input. REINFORCE is used to update the NN parameters.
Returns:
The 3-tuple (route_selection, routing_loss, final_loss), where:
- route_selection is an int 4-vector
- routing_loss is a float 4-vector
- final_loss is a float scalar.
"""
inputs = tf.constant([[-1.0], [-1.0], [1.0], [1.0]])
targets = tf.constant([[0.0], [0.0], [0.0], [0.0]])
paths = [plus_1, minus_1]
weights = tf.get_variable("w", [1, 2])
bias = tf.get_variable("b", [1, 1])
logits = tf.matmul(inputs, weights) + bias
# REINFORCE forward step
route_selection = sg.DistributionTensor(
distributions.Categorical, logits=logits)
# Accessing route_selection as a Tensor below forces a sample of
# the Categorical distribution based on its logits.
# This is equivalent to calling route_selection.value().
#
# route_selection.value() returns an int32 4-vector with random
# values in {0, 1}
# COPY+ROUTE+PASTE
outputs = split_apply_merge(inputs, route_selection, paths)
# flatten routing_loss to a row vector (from a column vector)
routing_loss = tf.reshape(tf.square(outputs - targets), shape=[-1])
# returns
# [stop_gradient(routing_loss) *
# route_selection.log_pmf(stop_gradient(route_selection.value()))],
# where log_pmf has gradients going all the way back to weights and bias.
# REINFORCE loss
score_function_losses = sg.surrogate_losses([routing_loss])
# calculate the entire loss:
# routing_loss, and the score function loss.
# in this case, the routing_loss depends on the variables only through
# "route_selection", which has a stop_gradients on it. so the
# gradient of the loss really come through the score function
all_loss = score_function_losses + [routing_loss]
final_loss = tf.reduce_sum(tf.add_n(all_loss))
return (route_selection, routing_loss, final_loss)
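For reference, the surrogate loss described in the comments above is the score-function (REINFORCE) estimator, which rests on the identity

\nabla_\theta \, \mathbb{E}_{x \sim p_\theta}[f(x)] = \mathbb{E}_{x \sim p_\theta}\big[ f(x) \, \nabla_\theta \log p_\theta(x) \big]

so minimizing stop_gradient(f(x)) * log_pmf(x) over samples x yields an unbiased estimate of the gradient of the expected loss.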
class REINFORCESimpleExample(tf.test.TestCase):
def testSplitApplyMerge(self):
# Repeatability. SGD has a tendency to jump around, even here.
tf.set_random_seed(1)
with self.test_session() as sess:
# Use sampling to train REINFORCE
with sg.value_type(sg.SampleAndReshapeValue(n=1)):
(route_selection,
routing_loss,
final_loss) = build_split_apply_merge_model()
sgd = tf.train.GradientDescentOptimizer(1.0).minimize(final_loss)
tf.initialize_all_variables().run()
for i in range(10):
# Run loss and inference step. This toy problem converges VERY quickly.
(routing_loss_v, final_loss_v, route_selection_v, _) = sess.run(
[routing_loss, final_loss, tf.identity(route_selection), sgd])
print(
"Iteration %d, routing loss: %s, final_loss: %s, "
"route selection: %s"
% (i, routing_loss_v, final_loss_v, route_selection_v))
self.assertAllEqual([0, 0, 1, 1], route_selection_v)
self.assertAllClose([0.0, 0.0, 0.0, 0.0], routing_loss_v)
self.assertAllClose(0.0, final_loss_v)
if __name__ == "__main__":
tf.test.main()


@@ -265,6 +265,16 @@ class DirichletMultinomialTest(tf.test.TestCase):
self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
self.assertEqual((), pmf_same.get_shape())
def testNonStrictTurnsOffAllChecks(self):
# Make totally invalid input.
with self.test_session():
alpha = [[-1., 2]] # alpha should be positive.
counts = [[1., 0], [0., -1]] # counts should be non-negative.
n = [-5.3] # n should be a non-negative integer equal to counts.sum.
dist = tf.contrib.distributions.DirichletMultinomial(
n, alpha, strict=False)
dist.pmf(counts).eval() # Should not raise.
if __name__ == '__main__':
tf.test.main()
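For contrast with testNonStrictTurnsOffAllChecks above, a hedged sketch of the default strict=True behavior on the same toy inputs (the validation ops attached to alpha and n should fire when the pmf is evaluated):

import tensorflow as tf

with tf.Session():
    alpha = [[-1., 2]]
    counts = [[1., 0], [0., -1]]
    n = [-5.3]
    dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)  # strict=True
    try:
        dist.pmf(counts).eval()
    except tf.errors.InvalidArgumentError:
        pass  # assert_positive / assert_non_negative dependencies fired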


@@ -38,21 +38,6 @@ def _assert_integer_form(x):
math_ops.round(casted_x), x.dtype))
def _check_alpha(alpha):
"""Check alpha for proper shape, values, then return tensor version."""
alpha = ops.convert_to_tensor(alpha, name='alpha_before_deps')
return control_flow_ops.with_dependencies(
[check_ops.assert_rank_at_least(alpha, 1),
check_ops.assert_positive(alpha)], alpha)
def _check_n(n):
"""Check n for proper shape, values, then return tensor version."""
n = ops.convert_to_tensor(n, name='n_before_deps')
return control_flow_ops.with_dependencies(
[check_ops.assert_non_negative(n), _assert_integer_form(n)], n)
def _log_combinations(n, counts, name='log_combinations'):
"""Log number of ways counts could have come in."""
# First a bit about the number of ways counts could have come in:
@@ -148,6 +133,7 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
n,
alpha,
allow_arbitrary_counts=False,
allow_nan=False,
strict=True,
name='DirichletMultinomial'):
"""Initialize a batch of DirichletMultinomial distributions.
@@ -163,8 +149,15 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
allow_arbitrary_counts: Boolean. This represents whether the pmf/cdf
allows for the `counts` tensor to be non-integral values.
The pmf/cdf are functions that can be evaluated at non-integral values,
but are only a distribution over non-negative integers.
strict: Not used (yet).
but are only a distribution over non-negative integers. If `strict` is
`False`, this assertion is turned off.
allow_nan: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
strict: Whether to assert valid values for parameters `alpha` and `n`, and
`x` in `pmf` and `log_pmf`. If False, correct behavior is not
guaranteed.
name: The name to prefix Ops created by this distribution class.
Examples:
@@ -178,9 +171,9 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
```
"""
# TODO(langmore): Should strict supersede allow_arbitrary_counts?
# Or work orthogonally to it? Implement correct usage of strict=True
self._allow_nan = allow_nan
self._strict = strict
self._name = name
self._allow_arbitrary_counts = allow_arbitrary_counts
with ops.op_scope([n, alpha], name):
# Broadcasting works because:
@@ -192,12 +185,9 @@
# explicitness.
# * All calls involving `counts` eventually require a broadcast between
# `counts` and alpha.
self._alpha = _check_alpha(alpha)
self._name = name
n = _check_n(n)
n = math_ops.cast(n, self._alpha.dtype)
self._n = n
self._alpha = self._check_alpha(alpha)
n = self._check_n(n)
self._n = math_ops.cast(n, self._alpha.dtype)
self._alpha_sum = math_ops.reduce_sum(
self._alpha, reduction_indices=[-1], keep_dims=False)
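For shape intuition, a standalone numpy rendering of that reduction and the broadcast the comments describe (illustrative only):

import numpy as np

alpha = np.array([[1., 2., 3.],
                  [4., 5., 6.]])   # [batch=2, k=3]
alpha_sum = alpha.sum(axis=-1)     # [2]: last dim reduced, keepdims off
counts = np.array([2., 1., 0.])    # [k=3] broadcasts against alpha's last dim
print((alpha + counts).shape)      # (2, 3)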
@@ -217,6 +207,11 @@
"""Parameter defining this distribution."""
return self._alpha
@property
def allow_nan(self):
"""Boolean describing behavior when a stat is undefined for batch member."""
return self._allow_nan
@property
def strict(self):
"""Boolean describing behavior on invalid input."""
@@ -368,7 +363,9 @@
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name='counts_before_deps')
counts = ops.convert_to_tensor(counts, name='counts')
if not self.strict:
return counts
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
dependencies = [check_ops.assert_non_negative(counts),
check_ops.assert_equal(self._n,
@@ -378,3 +375,18 @@
dependencies += [_assert_integer_form(counts)]
return control_flow_ops.with_dependencies(dependencies, counts)
def _check_alpha(self, alpha):
alpha = ops.convert_to_tensor(alpha, name='alpha')
if not self.strict:
return alpha
return control_flow_ops.with_dependencies(
[check_ops.assert_rank_at_least(alpha, 1),
check_ops.assert_positive(alpha)], alpha)
def _check_n(self, n):
n = ops.convert_to_tensor(n, name='n')
if not self.strict:
return n
return control_flow_ops.with_dependencies(
[check_ops.assert_non_negative(n), _assert_integer_form(n)], n)
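The pattern behind all three checks: control_flow_ops.with_dependencies returns its output tensor unchanged but gates it on assert ops, so invalid values fail at run time rather than at graph-construction time. A minimal sketch, assuming the TF 0.9-era internal modules imported by this file (validated_positive is an illustrative name):

from tensorflow.python.framework import ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops

def validated_positive(x):
    x = ops.convert_to_tensor(x, name='x')
    # Evaluating the returned tensor first runs the assert op.
    return control_flow_ops.with_dependencies(
        [check_ops.assert_positive(x)], x)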


@@ -51,11 +51,11 @@ def stratified_sample(data, labels, probs, batch_size,
threads_per_queue: Number of threads for each per-class queue.
name: Optional prefix for ops created by this function.
Raises:
AssertionError: enqueue_many is True and labels doesn't have a batch
ValueError: enqueue_many is True and labels doesn't have a batch
dimension, or if enqueue_many is False and labels isn't a scalar.
AssertionError: enqueue_many is True, and batch dimension on data and labels
ValueError: enqueue_many is True, and batch dimension of data and labels
don't match.
AssertionError: if probs don't sum to one.
ValueError: if probs don't sum to one.
TFAssertion: if labels aren't integers in [0, num classes).
Returns:
(data_batch, label_batch)
@@ -81,11 +81,11 @@ def stratified_sample(data, labels, probs, batch_size,
labels = array_ops.expand_dims(labels, 0)
# Validate that input is consistent.
data, labels = _verify_input(data, labels, probs)
data, labels, probs = _verify_input(data, labels, probs)
# Make per-class queues.
per_class_queues = _make_per_class_queues(
data, labels, len(probs), queue_capacity, threads_per_queue)
data, labels, probs.size, queue_capacity, threads_per_queue)
# Use the per-class queues to generate stratified batches.
return _get_batch(per_class_queues, probs, batch_size)
@@ -93,15 +93,19 @@ def stratified_sample(data, labels, probs, batch_size,
def _verify_input(data, labels, probs):
"""Verify that batched inputs are well-formed."""
# Probabilities must be a numpy array or a Python list.
if not (isinstance(probs, np.ndarray) or isinstance(probs, list)):
raise ValueError('Probabilities must be python or numpy array')
# Probabilities must be able to be converted to a 1D non-object numpy array.
probs = np.asarray(probs)
if probs.dtype == np.dtype('object'):
raise ValueError('Probabilities must be able to be converted to a numpy '
'array.')
if len(probs.shape) != 1:
raise ValueError('Probabilities must be 1D.')
# Probabilities must sum to one.
# TODO(joelshor): Investigate whether logits should be passed instead of
# probs.
if np.sum(probs) != 1.0:
raise ValueError('Probabilities must sum to one.')
if not np.isclose(np.sum(probs), 1.0):
raise ValueError('Probabilities must sum to one.', np.sum(probs))
# Labels tensor should only have batch dimension.
labels.get_shape().assert_has_rank(1)
@@ -129,7 +133,7 @@ def _verify_input(data, labels, probs):
check_ops.assert_less(labels, math_ops.cast(len(probs), labels.dtype))],
labels)
return data, labels
return data, labels, probs
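The probability checks above can be exercised standalone; a quick numpy illustration (not part of the diff):

import numpy as np

probs = np.asarray([0.1, 0.2, 0.3, 0.4])
assert probs.dtype != np.dtype('object')  # rejects ragged/object inputs
assert len(probs.shape) == 1              # must be 1D
assert np.isclose(np.sum(probs), 1.0)     # tolerant sum-to-one check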
def _make_per_class_queues(data, labels, num_classes, queue_capacity,
@@ -161,12 +165,12 @@ def _make_per_class_queues(data, labels, num_classes, queue_capacity,
def _get_batch(per_class_queues, probs, batch_size):
"""Generates batches according to per-class-probabilities."""
num_classes = len(probs)
num_classes = probs.size
# Number of examples per class is governed by a multinomial distribution.
# Note: multinomial takes unnormalized log probabilities for its first
# argument.
# argument, of dimension [batch_size, num_classes].
examples = random_ops.multinomial(
math_ops.log([[float(x) for x in probs]]), batch_size)
np.expand_dims(np.log(probs), 0), batch_size)
# Prepare the data and label batches.
val_list = []
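The sampling step can also be tried in isolation. A standalone sketch, assuming the era's tf.multinomial wrapper over random_ops.multinomial, which takes [batch, num_classes] log-probabilities and returns [batch, num_samples] class ids:

import numpy as np
import tensorflow as tf

probs = np.array([0.5, 0.25, 0.25])
samples = tf.multinomial(np.expand_dims(np.log(probs), 0), 8)

with tf.Session() as sess:
    print(sess.run(samples))  # e.g. [[0 0 2 0 1 0 2 0]]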


@@ -63,6 +63,11 @@ class SamplingOpsTest(tf.test.TestCase):
tf.contrib.framework.sampling_ops.stratified_sample(
val, label, np.array([.1] * 5), batch_size)
# Probabilities must be 1D.
with self.assertRaises(ValueError):
tf.contrib.framework.sampling_ops.stratified_sample(
val, label, np.array([[.25, .25], [.25, .25]]), batch_size)
def testRuntimeAssertionFailures(self):
probs = [.2] * 5
vals = tf.zeros([3, 1])
@@ -75,14 +80,14 @@
# Set up graph with illegal label vector.
label_ph = tf.placeholder(tf.int32, shape=[None])
batch_tf = tf.contrib.framework.sampling_ops._verify_input(
vals_tf, lbls_tf, _ = tf.contrib.framework.sampling_ops._verify_input(
vals, label_ph, probs)
for illegal_label in illegal_labels:
# Run session that should fail.
with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(batch_tf, feed_dict={label_ph: illegal_label})
sess.run([vals_tf, lbls_tf], feed_dict={label_ph: illegal_label})
def testBatchingBehavior(self):
batch_size = 20
@@ -119,13 +124,14 @@
# Set up graph with placeholders.
vals_ph = tf.placeholder(tf.float32) # completely undefined shape
labels_ph = tf.placeholder(tf.int32) # completely undefined shape
batch_tf = tf.contrib.framework.sampling_ops._verify_input(
vals_tf, lbls_tf, _ = tf.contrib.framework.sampling_ops._verify_input(
vals_ph, labels_ph, probs)
# Run graph to make sure there are no shape-related runtime errors.
for vals, labels in legal_input_pairs:
with self.test_session() as sess:
sess.run(batch_tf, feed_dict={vals_ph: vals, labels_ph: labels})
sess.run([vals_tf, lbls_tf], feed_dict={vals_ph: vals,
labels_ph: labels})
def testNormalBehavior(self):
# Set up graph.


@@ -18,6 +18,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/learn/python/learn/datasets",
"//tensorflow/contrib/session_bundle:exporter",
"//tensorflow/python:framework",
],
)
@@ -220,6 +221,18 @@ py_test(
],
)
py_test(
name = "compare_test",
size = "small",
srcs = ["python/learn/tests/dataframe/compare_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
],
)
py_test(
name = "early_stopping_test",
size = "medium",
@@ -562,6 +575,19 @@ py_test(
],
)
py_test(
name = "export_test",
size = "small",
srcs = ["python/learn/utils/export_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
py_test(
name = "stability_test",
size = "small",


@@ -32,5 +32,17 @@ from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source
from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource
from tensorflow.contrib.learn.python.learn.dataframe.transforms.sum import Sum
# pylint: disable=g-import-not-at-top,g-bad-import-order
# Unary Transform registration
from tensorflow.contrib.learn.python.learn.dataframe.transforms import unary_transforms as _ut
for ut_def in _ut.UNARY_TRANSFORMS:
_ut.register_unary_op(*ut_def)
# Comparison Transform registration
from tensorflow.contrib.learn.python.learn.dataframe.transforms import compare as _cmp
for ct_def in _cmp.COMPARISON_TRANSFORMS:
_cmp.register_comparison_ops(*ct_def)
__all__ = ['DataFrame', 'Series', 'TransformedSeries', 'TensorFlowDataFrame',
'parameter', 'Transform']


@@ -0,0 +1,156 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transforms for comparing pairs of `Series`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.dataframe import series
from tensorflow.contrib.learn.python.learn.dataframe import transform
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
# Each entry is a mapping from registered_name to operation. Each operation is
# wrapped in a transform and then registered as the member function
# `Series.registered_name()`.
COMPARISON_TRANSFORMS = [("__eq__", math_ops.equal),
("__gt__", math_ops.greater),
("__ge__", math_ops.greater_equal),
("__lt__", math_ops.less),
("__le__", math_ops.less_equal)]
SERIES_DOC_FORMAT_STRING = (
"A `Transform` that uses `{0}` to compare two Series. "
"Documentation for `{0}`: \n\n {1}"
)
SCALAR_DOC_FORMAT_STRING = (
"A `Transform` that uses `{0}` to compare a Series and a scalar. "
"Documentation for `{0}`: \n\n {1}"
)
class SeriesComparisonTransform(transform.Transform):
"""Parent class for `Transform`s that compare `Series` elementwise."""
@property
def input_valency(self):
return 2
@property
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
# TODO(jamieas): consider supporting sparse comparisons.
if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
input_tensors[1], ops.SparseTensor):
raise TypeError("{} does not support SparseTensors".format(type(
self).__name__))
# pylint: disable=not-callable
return self.return_type(self._compare(input_tensors[0], input_tensors[1]))
class ScalarComparisonTransform(transform.Transform):
"""Parent class for `Transform`s that compare `Series` to a scalar."""
def __init__(self, threshold):
if isinstance(threshold, series.Series):
raise ValueError(
"{} is used to compare Series with scalars. It was called with "
"another Series.".format(
type(self).__name__))
super(ScalarComparisonTransform, self).__init__()
self._threshold = threshold
@transform.parameter
def threshold(self):
return self._threshold
@property
def input_valency(self):
return 1
@property
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
input_tensor = input_tensors[0]
if isinstance(input_tensor, ops.SparseTensor):
result = ops.SparseTensor(input_tensor.indices,
self._compare(input_tensor.values),
input_tensor.shape)
else:
result = self._compare(input_tensor)
# pylint: disable=not-callable
return self.return_type(result)
# pylint: disable=unused-argument
def register_comparison_ops(method_name, operation):
"""Registers `Series` member functions for comparisons.
Args:
method_name: the name of the method that will be created in `Series`.
operation: TensorFlow operation used for comparison.
"""
# Define series-series comparison `Transform`.
@property
def series_name(self):
return operation.__name__
series_doc = SERIES_DOC_FORMAT_STRING.format(operation.__name__,
operation.__doc__)
def series_compare(self, x, y):
return operation(x, y)
series_transform_cls = type("series_{}".format(operation.__name__),
(SeriesComparisonTransform,),
{"name": series_name,
"__doc__": series_doc,
"_compare": series_compare})
# Define series-scalar comparison `Transform`.
@property
def scalar_name(self):
return "scalar_{}".format(operation.__name__)
scalar_doc = SCALAR_DOC_FORMAT_STRING.format(operation.__name__,
operation.__doc__)
def scalar_compare(self, x):
return operation(x, self.threshold)
scalar_transform_cls = type("scalar_{}".format(operation.__name__),
(ScalarComparisonTransform,),
{"name": scalar_name,
"__doc__": scalar_doc,
"_compare": scalar_compare})
# Define function that delegates to the two `Transforms`.
def _comparison_fn(self, other, *args, **kwargs):
# pylint: disable=not-callable,abstract-class-instantiated
if isinstance(other, series.Series):
return series_transform_cls(*args, **kwargs)([self, other])[0]
return scalar_transform_cls(other, *args, **kwargs)([self])[0]
# Register new member function of `Series`.
setattr(series.Series, method_name, _comparison_fn)
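The registration trick above can be miniaturized in plain Python; Box and register_comparison below are illustrative inventions, not part of the codebase:

import operator

class Box(object):
    """Toy value wrapper standing in for a Series."""

    def __init__(self, v):
        self.v = v

def register_comparison(op_name, op):
    def _cmp(self, other):
        other_v = other.v if isinstance(other, Box) else other
        return op(self.v, other_v)
    setattr(Box, op_name, _cmp)  # Box instances now support the operator

for name, op in [("__gt__", operator.gt), ("__le__", operator.le)]:
    register_comparison(name, op)

assert Box(3) > Box(2)  # Box-vs-Box path
assert Box(1) <= 5      # Box-vs-scalar path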


@@ -52,9 +52,7 @@ DOC_FORMAT_STRING = (
# pylint: disable=unused-argument
def _register_unary_op(registered_name,
operation):
def register_unary_op(registered_name, operation):
"""Creates a `Transform` that wraps a unary tensorflow operation.
If `registered_name` is specified, the `Transform` is registered as a member
@@ -100,7 +98,3 @@ def _register_unary_op(registered_name,
"_apply_transform": _apply_transform})
series.Series.register_unary_op(registered_name)(cls)
for ut in UNARY_TRANSFORMS:
_register_unary_op(*ut)


@@ -34,7 +34,6 @@ from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
@@ -47,6 +46,238 @@ from tensorflow.python.ops import variables
from tensorflow.python.training import training
class _ComposableModel(object):
"""ABC for building blocks that can be used to create estimators.
Subclasses need to implement the following methods:
- build_model
- _get_optimizer
See below for the required signatures.
_ComposableModel and its subclasses are not part of the public tf.learn API.
"""
def __init__(self,
num_label_columns,
optimizer,
weight_collection_name,
gradient_clip_norm):
"""Common initialization for all _ComposableModel objects.
Args:
num_label_columns: The number of label/target columns.
optimizer: An instance of `tf.Optimizer` used to apply gradients to
the model. If `None`, will use an FTRL optimizer.
weight_collection_name: A string defining the name to use for the
collection of weights (e.g. 'dnn').
gradient_clip_norm: A float > 0. If provided, gradients are clipped
to their global norm with this clipping ratio. See
tf.clip_by_global_norm for more details.
"""
self._num_label_columns = num_label_columns
self._optimizer = optimizer
self._weight_collection_name = weight_collection_name
self._gradient_clip_norm = gradient_clip_norm
self._feature_columns = None
def build_model(self, features, feature_columns, is_training):
"""Builds the model that can calculate the logits.
Args:
features: A mapping from feature columns to tensors.
feature_columns: An iterable containing all the feature columns used
by the model. All items in the set should be instances of
classes derived from `FeatureColumn`.
is_training: Set to True when training, False otherwise.
Returns:
The logits for this model.
"""
raise NotImplementedError
def get_train_step(self, loss):
"""Returns the ops to run to perform a training step on this estimator.
Args:
loss: The loss to use when calculating gradients.
Returns:
The ops to run to perform a training step.
"""
my_vars = self._get_vars()
if not (self._get_feature_columns() or my_vars):
return []
grads = gradients.gradients(loss, my_vars)
if self._gradient_clip_norm:
grads, _ = clip_ops.clip_by_global_norm(grads, self._gradient_clip_norm)
self._optimizer = self._get_optimizer()
return [self._optimizer.apply_gradients(zip(grads, my_vars))]
def _get_feature_columns(self):
if not self._feature_columns:
return None
feature_column_ops.check_feature_columns(self._feature_columns)
return sorted(set(self._feature_columns), key=lambda x: x.key)
def _get_feature_dict(self, features):
if isinstance(features, dict):
return features
return {"": features}
def _get_vars(self):
if self._get_feature_columns():
return ops.get_collection(self._weight_collection_name)
return []
def _get_optimizer(self):
raise NotImplementedError
class _LinearComposableModel(_ComposableModel):
"""A _ComposableModel that implements linear regression.
Instances of this class can be used to build estimators through the use
of composition.
"""
def __init__(self,
num_label_columns,
optimizer=None,
gradient_clip_norm=None):
"""Initializes _LinearComposableModel objects.
Args:
num_label_columns: The number of label/target columns.
optimizer: An instance of `tf.Optimizer` used to apply gradients to
the model. If `None`, will use an FTRL optimizer.
gradient_clip_norm: A float > 0. If provided, gradients are clipped
to their global norm with this clipping ratio. See
tf.clip_by_global_norm for more details.
"""
super(_LinearComposableModel, self).__init__(
num_label_columns=num_label_columns,
optimizer=optimizer,
weight_collection_name="linear",
gradient_clip_norm=gradient_clip_norm)
def build_model(self, features, feature_columns, is_training):
"""See base class."""
features = self._get_feature_dict(features)
self._feature_columns = feature_columns
logits, _, _ = layers.weighted_sum_from_feature_columns(
columns_to_tensors=features,
feature_columns=self._get_feature_columns(),
num_outputs=self._num_label_columns,
weight_collections=[self._weight_collection_name],
name="linear")
return logits
def _get_optimizer(self):
if self._optimizer is None:
self._optimizer = "Ftrl"
if isinstance(self._optimizer, six.string_types):
default_learning_rate = 1. / math.sqrt(len(self._get_feature_columns()))
self._optimizer = layers.OPTIMIZER_CLS_NAMES[self._optimizer](
learning_rate=default_learning_rate)
return self._optimizer
class _DNNComposableModel(_ComposableModel):
"""A _ComposableModel that implements a DNN.
Instances of this class can be used to build estimators through the use
of composition.
"""
def __init__(self,
num_label_columns,
hidden_units,
optimizer=None,
activation_fn=nn.relu,
dropout=None,
gradient_clip_norm=None,
config=None):
"""Initializes _DNNComposableModel objects.
Args:
num_label_columns: The number of label/target columns.
hidden_units: List of hidden units per layer. All layers are fully
connected.
optimizer: An instance of `tf.Optimizer` used to apply gradients to
the model. If `None`, will use an Adagrad optimizer.
activation_fn: Activation function applied to each layer. If `None`,
will use `tf.nn.relu`.
dropout: When not None, the probability we will drop out
a given coordinate.
gradient_clip_norm: A float > 0. If provided, gradients are clipped
to their global norm with this clipping ratio. See
tf.clip_by_global_norm for more details.
config: RunConfig object to configure the runtime settings.
"""
super(_DNNComposableModel, self).__init__(
num_label_columns=num_label_columns,
optimizer=optimizer,
weight_collection_name="DNN",
gradient_clip_norm=gradient_clip_norm)
self._hidden_units = hidden_units
self._activation_fn = activation_fn
self._dropout = dropout
self._config = config
def _add_hidden_layer_summary(self, value, tag):
# TODO(zakaria): Move this code to tf.learn and add test.
logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
nn.zero_fraction(value))
logging_ops.histogram_summary("%s:activation" % tag, value)
def build_model(self, features, feature_columns, is_training):
"""See base class."""
features = self._get_feature_dict(features)
self._feature_columns = feature_columns
net = layers.input_from_feature_columns(
features,
self._get_feature_columns(),
weight_collections=[self._weight_collection_name])
for layer_id, num_hidden_units in enumerate(self._hidden_units):
with variable_scope.variable_op_scope(
[net], "hiddenlayer_%d" % layer_id,
partitioner=partitioned_variables.min_max_variable_partitioner(
max_partitions=self._config.num_ps_replicas)) as scope:
net = layers.fully_connected(
net,
num_hidden_units,
activation_fn=self._activation_fn,
variables_collections=[self._weight_collection_name],
scope=scope)
if self._dropout is not None and is_training:
net = layers.dropout(
net,
keep_prob=(1.0 - self._dropout))
self._add_hidden_layer_summary(net, scope.name)
with variable_scope.variable_op_scope(
[net], "dnn_logits",
partitioner=partitioned_variables.min_max_variable_partitioner(
max_partitions=self._config.num_ps_replicas)) as scope:
logits = layers.fully_connected(
net,
self._num_label_columns,
activation_fn=None,
variables_collections=[self._weight_collection_name],
scope=scope)
self._add_hidden_layer_summary(logits, "dnn_logits")
return logits
def _get_optimizer(self):
if self._optimizer is None:
self._optimizer = "Adagrad"
if isinstance(self._optimizer, six.string_types):
self._optimizer = layers.OPTIMIZER_CLS_NAMES[self._optimizer](
learning_rate=0.05)
return self._optimizer
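To make the composition idea concrete, a toy plain-Python analogue (ConstModel is hypothetical; the real build_model calls return logits tensors, which the combined estimator below sums):

class ConstModel(object):
    def __init__(self, value):
        self.value = value

    def build_model(self, features):
        return self.value  # stands in for a logits tensor

linear, dnn = ConstModel(1.0), ConstModel(2.0)
logits = linear.build_model(None) + dnn.build_model(None)  # sum of the parts
assert logits == 3.0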
# TODO(ispir): Increase test coverage
class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
"""An estimator for TensorFlow Linear and DNN joined training models.
@@ -110,6 +341,21 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
"""
super(_DNNLinearCombinedBaseEstimator, self).__init__(model_dir=model_dir,
config=config)
self._linear_model = _LinearComposableModel(
num_label_columns=target_column.num_label_columns,
optimizer=linear_optimizer,
gradient_clip_norm=gradient_clip_norm)
self._dnn_model = _DNNComposableModel(
num_label_columns=target_column.num_label_columns,
hidden_units=dnn_hidden_units,
optimizer=dnn_optimizer,
activation_fn=dnn_activation_fn,
dropout=dnn_dropout,
gradient_clip_norm=gradient_clip_norm,
config=self._config) if dnn_hidden_units else None
self._linear_feature_columns = linear_feature_columns
self._linear_optimizer = linear_optimizer
self._dnn_feature_columns = dnn_feature_columns
@@ -175,22 +421,11 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
loss = self._loss(logits, targets, features)
logging_ops.scalar_summary("loss", loss)
linear_vars = self._get_linear_vars()
dnn_vars = self._get_dnn_vars()
grads = gradients.gradients(loss, dnn_vars + linear_vars)
if self._gradient_clip_norm:
grads, _ = clip_ops.clip_by_global_norm(grads,
self._gradient_clip_norm)
linear_train_step = self._linear_model.get_train_step(loss)
dnn_train_step = (self._dnn_model.get_train_step(loss)
if self._dnn_model else [])
dnn_grads = grads[0:len(dnn_vars)]
linear_grads = grads[len(dnn_vars):]
train_ops = self._get_linear_training_ops(
linear_grads, linear_vars) + self._get_dnn_training_ops(dnn_grads,
dnn_vars)
train_step = control_flow_ops.group(*train_ops, name="combined_training_op")
with ops.control_dependencies([train_step]):
with ops.control_dependencies(linear_train_step + dnn_train_step):
with ops.get_default_graph().colocate_with(global_step):
return state_ops.assign_add(global_step, 1).op, loss
@@ -232,54 +467,13 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
feature_column_ops.check_feature_columns(self._dnn_feature_columns)
return sorted(set(self._dnn_feature_columns), key=lambda x: x.key)
def _dnn_logits(self, features, is_training=False):
net = layers.input_from_feature_columns(
features,
self._get_dnn_feature_columns(),
weight_collections=[self._dnn_weight_collection])
for layer_id, num_hidden_units in enumerate(self._dnn_hidden_units):
with variable_scope.variable_op_scope(
[net], "hiddenlayer_%d" % layer_id,
partitioner=partitioned_variables.min_max_variable_partitioner(
max_partitions=self._config.num_ps_replicas)) as scope:
net = layers.fully_connected(
net,
num_hidden_units,
activation_fn=self._dnn_activation_fn,
variables_collections=[self._dnn_weight_collection],
scope=scope)
if self._dnn_dropout is not None and is_training:
net = layers.dropout(
net,
keep_prob=(1.0 - self._dnn_dropout))
self._add_hidden_layer_summary(net, scope.name)
with variable_scope.variable_op_scope(
[net], "dnn_logit",
partitioner=partitioned_variables.min_max_variable_partitioner(
max_partitions=self._config.num_ps_replicas)) as scope:
logit = layers.fully_connected(
net,
self._target_column.num_label_columns,
activation_fn=None,
variables_collections=[self._dnn_weight_collection],
scope=scope)
self._add_hidden_layer_summary(logit, "dnn_logit")
return logit
def _dnn_logits(self, features, is_training):
return self._dnn_model.build_model(
features, self._dnn_feature_columns, is_training)
def _add_hidden_layer_summary(self, value, tag):
# TODO(zakaria): Move this code to tf.learn and add test.
logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
nn.zero_fraction(value))
logging_ops.histogram_summary("%s:activation" % tag, value)
def _linear_logits(self, features):
logits, _, _ = layers.weighted_sum_from_feature_columns(
columns_to_tensors=features,
feature_columns=self._get_linear_feature_columns(),
num_outputs=self._target_column.num_label_columns,
weight_collections=[self._linear_weight_collection],
name="linear")
return logits
def _linear_logits(self, features, is_training):
return self._linear_model.build_model(
features, self._linear_feature_columns, is_training)
def _get_feature_dict(self, features):
if isinstance(features, dict):
@@ -318,12 +512,12 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
features = self._get_feature_dict(features)
if linear_feature_columns and dnn_feature_columns:
logits = (self._linear_logits(features) +
self._dnn_logits(features, is_training=is_training))
logits = (self._linear_logits(features, is_training) +
self._dnn_logits(features, is_training))
elif dnn_feature_columns:
logits = self._dnn_logits(features, is_training=is_training)
logits = self._dnn_logits(features, is_training)
else:
logits = self._linear_logits(features)
logits = self._linear_logits(features, is_training)
if self._enable_centered_bias:
return nn.bias_add(logits, self._centered_bias())


@@ -250,13 +250,22 @@ class BaseEstimator(sklearn.BaseEstimator):
name=None):
"""Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
Args:
x: features.
y: targets.
input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
feed_fn: Function creating a feed dict every time it is called. Called
once per iteration.
batch_size: minibatch size to use on the input, defaults to first
@@ -290,11 +299,14 @@ class BaseEstimator(sklearn.BaseEstimator):
if metrics is not None and not isinstance(metrics, dict):
raise ValueError('Metrics argument should be None or dict. '
'Got %s.' % metrics)
return self._evaluate_model(input_fn=input_fn,
feed_fn=feed_fn,
steps=steps,
metrics=metrics,
name=name)
eval_results, global_step = self._evaluate_model(input_fn=input_fn,
feed_fn=feed_fn,
steps=steps,
metrics=metrics,
name=name)
if eval_results is not None:
eval_results.update({'global_step': global_step})
return eval_results
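A hedged sketch of such an input function (eval_input_fn is an illustrative name; it assumes tf.train.limit_epochs, which raises OutOfRangeError once num_epochs is exhausted):

import tensorflow as tf

def eval_input_fn():
    x = tf.train.limit_epochs(tf.constant([[1.0], [2.0]]), num_epochs=1)
    y = tf.train.limit_epochs(tf.constant([[2.0], [4.0]]), num_epochs=1)
    return {'x': x}, y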
def predict(self, x=None, input_fn=None, batch_size=None, outputs=None):
"""Returns predictions for given features.
@@ -477,15 +489,16 @@ class BaseEstimator(sklearn.BaseEstimator):
# Add default monitors.
if monitors is None:
monitors = []
monitors += monitors_lib.get_default_monitors(
loss_op=loss_op,
summary_op=logging_ops.get_summary_op(),
save_summary_steps=self._config.save_summary_steps,
summary_writer=graph_actions.get_summary_writer(self._model_dir))
is_chief = self._config.task == 0
if not is_chief:
# Run monitors only on chief.
if is_chief:
monitors += monitors_lib.get_default_monitors(
loss_op=loss_op,
summary_op=logging_ops.get_summary_op(),
save_summary_steps=self._config.save_summary_steps,
summary_writer=graph_actions.get_summary_writer(self._model_dir))
else:
monitors = []
# Setup monitors.
@@ -545,7 +558,7 @@ class BaseEstimator(sklearn.BaseEstimator):
# TODO(wicke): Remove this once Model and associated code are gone.
if (hasattr(self._config, 'execution_mode') and
self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset')):
return
return None, None
# Check that model has been trained.
checkpoint_path = self._model_dir
@@ -564,7 +577,7 @@ class BaseEstimator(sklearn.BaseEstimator):
self._check_inputs(features, targets)
eval_dict = self._get_eval_ops(features, targets, metrics)
update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
eval_results, _ = graph_actions.evaluate(
eval_results, current_global_step = graph_actions.evaluate(
graph=g,
output_dir=eval_dir,
checkpoint_path=checkpoint_path,
@@ -574,7 +587,8 @@ class BaseEstimator(sklearn.BaseEstimator):
supervisor_master=self._config.master,
feed_fn=feed_fn,
max_steps=steps)
return eval_results
return eval_results, current_global_step
def _get_features_from_input_fn(self, input_fn):
result = input_fn()


@@ -240,6 +240,8 @@ class EstimatorTest(tf.test.TestCase):
predictions = est.predict(x=boston.data)
other_score = _sklearn.mean_squared_error(predictions, boston.target)
self.assertAllClose(other_score, scores['MSE'])
self.assertTrue('global_step' in scores)
self.assertEqual(scores['global_step'], 100)
def testIrisAll(self):
iris = tf.contrib.learn.datasets.load_iris()
@@ -257,6 +259,8 @@ class EstimatorTest(tf.test.TestCase):
axis=1))
other_score = _sklearn.accuracy_score(iris.target, predictions['class'])
self.assertAllClose(other_score, scores['accuracy'])
self.assertTrue('global_step' in scores)
self.assertEqual(scores['global_step'], 100)
def testIrisInputFn(self):
iris = tf.contrib.learn.datasets.load_iris()


@@ -22,6 +22,7 @@ from __future__ import print_function
import numpy as np
import six
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
@@ -520,3 +521,35 @@ class GraphDump(BaseMonitor):
else:
matched.append(key)
return matched, non_matched
class ExportMonitor(EveryN):
"""Monitor that exports Estimator every N steps."""
def __init__(self, every_n_steps, export_dir, exports_to_keep=5):
"""Initializes ExportMonitor.
Args:
every_n_steps: Run monitor every N steps.
export_dir: str, folder to export to.
exports_to_keep: int, number of exports to keep.
"""
super(ExportMonitor, self).__init__(every_n_steps=every_n_steps)
self.export_dir = export_dir
self.exports_to_keep = exports_to_keep
def every_n_step_end(self, step, outputs):
super(ExportMonitor, self).every_n_step_end(step, outputs)
try:
export.export_estimator(self._estimator, self.export_dir,
exports_to_keep=self.exports_to_keep)
except RuntimeError:
# Currently we are not synchronized with saving checkpoints, which leads to
# runtime errors when we are calling export on the same global step.
logging.info("Skipping exporting for the same step. "
"Consider exporting less frequently.")
def end(self):
super(ExportMonitor, self).end()
export.export_estimator(self._estimator, self.export_dir,
exports_to_keep=self.exports_to_keep)
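ExportMonitor follows the EveryN hook pattern; a hedged sketch of a custom subclass (PrintMonitor is hypothetical, assuming EveryN is reachable via learn.monitors the same way ExportMonitor is):

from tensorflow.contrib import learn

class PrintMonitor(learn.monitors.EveryN):
    """Toy monitor that logs the step every N steps."""

    def every_n_step_end(self, step, outputs):
        super(PrintMonitor, self).every_n_step_end(step, outputs)
        print("reached step %d" % step)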


@@ -0,0 +1,78 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for comparison transforms."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.dataframe import tensorflow_dataframe as df
from tensorflow.contrib.learn.python.learn.dataframe.transforms.compare import COMPARISON_TRANSFORMS
NUMPY_ARRAY_SIZE = 100
SCALAR = 50
class CompareTestCase(tf.test.TestCase):
"""Test class for comparison transforms."""
@classmethod
def add_test_case(cls, fn_name, op):
def _test(self):
rng = np.arange(-NUMPY_ARRAY_SIZE // 2,
NUMPY_ARRAY_SIZE // 2,
dtype="float32")
frame = df.TensorFlowDataFrame.from_numpy(rng,
batch_size=len(rng),
shuffle=False)
frame["sqr"] = frame["value"].square()
self.assertTrue(hasattr(frame["value"], fn_name))
frame["series_result"] = getattr(frame["value"],
fn_name)(frame["sqr"])
frame["scalar_result"] = getattr(frame["value"], fn_name)(SCALAR)
frame_built = frame.build()
expected_series_tensor = op(frame_built["value"], frame_built["sqr"])
actual_series_tensor = frame_built["series_result"]
expected_scalar_tensor = op(frame_built["value"], SCALAR)
actual_scalar_tensor = frame_built["scalar_result"]
session = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=session, coord=coord)
actual_series, expected_series, actual_scalar, expected_scalar = (
session.run([actual_series_tensor, expected_series_tensor,
actual_scalar_tensor, expected_scalar_tensor]))
coord.request_stop()
coord.join(threads)
np.testing.assert_almost_equal(expected_series, actual_series)
np.testing.assert_almost_equal(expected_scalar, actual_scalar)
setattr(cls, "test{}".format(op.__name__), _test)
for ct in COMPARISON_TRANSFORMS:
CompareTestCase.add_test_case(*ct)
if __name__ == "__main__":
tf.test.main()


@@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
import csv
import math
import tempfile
import numpy as np
@@ -62,13 +63,14 @@ def _assert_df_equals_dict(expected_df, actual_dict):
def _make_test_csv():
f = tempfile.NamedTemporaryFile(delete=False, mode="w")
w = csv.writer(f)
w.writerow(["int", "float", "bool"])
w.writerow(["int", "float", "bool", "string"])
for _ in range(100):
intvalue = np.random.randint(-10, 10)
floatvalue = np.random.rand()
boolvalue = int(np.random.rand() > 0.3)
stringvalue = "S: %.4f" % np.random.rand()
row = [intvalue, floatvalue, boolvalue]
row = [intvalue, floatvalue, boolvalue, stringvalue]
w.writerow(row)
f.close()
return f.name
@@ -77,14 +79,16 @@ def _make_test_csv():
def _make_test_csv_sparse():
f = tempfile.NamedTemporaryFile(delete=False, mode="w")
w = csv.writer(f)
w.writerow(["int", "float", "bool"])
w.writerow(["int", "float", "bool", "string"])
for _ in range(100):
# leave columns empty; these will be read as default value (e.g. 0 or NaN)
intvalue = np.random.randint(-10, 10) if np.random.rand() > 0.5 else ""
floatvalue = np.random.rand() if np.random.rand() > 0.5 else ""
boolvalue = int(np.random.rand() > 0.3) if np.random.rand() > 0.5 else ""
stringvalue = (("S: %.4f" % np.random.rand())
if np.random.rand() > 0.5 else "")
row = [intvalue, floatvalue, boolvalue]
row = [intvalue, floatvalue, boolvalue, stringvalue]
w.writerow(row)
f.close()
return f.name
@@ -180,7 +184,7 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
enqueue_size = 7
data_path = _make_test_csv()
default_values = [0, 0.0, 0]
default_values = [0, 0.0, 0, ""]
pandas_df = pd.read_csv(data_path)
tensorflow_df = df.TensorFlowDataFrame.from_csv(
@@ -200,7 +204,7 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
expected_num_batches = (num_epochs * 100) // batch_size
data_path = _make_test_csv()
default_values = [0, 0.0, 0]
default_values = [0, 0.0, 0, ""]
tensorflow_df = df.TensorFlowDataFrame.from_csv(
[data_path],
@@ -221,10 +225,17 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
feature_spec = {
"int": tf.FixedLenFeature(None, dtypes.int16, np.nan),
"float": tf.VarLenFeature(dtypes.float16),
"bool": tf.VarLenFeature(dtypes.bool)
"bool": tf.VarLenFeature(dtypes.bool),
"string": tf.FixedLenFeature(None, dtypes.string, "")
}
pandas_df = pd.read_csv(data_path)
pandas_df = pd.read_csv(data_path, dtype={"string": object})
# Pandas insanely uses NaN for empty cells in a string column.
# And, we can't use Pandas replace() to fix them because nan != nan
s = pandas_df["string"]
for i in range(0, len(s)):
if isinstance(s[i], float) and math.isnan(s[i]):
s[i] = ""
tensorflow_df = df.TensorFlowDataFrame.from_csv_with_feature_spec(
[data_path],
batch_size=batch_size,


@@ -20,3 +20,4 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.contrib.learn.python.learn.utils.export import export_estimator


@@ -0,0 +1,134 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import gc
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import saver as tf_saver
def _get_first_op_from_collection(collection_name):
"""Get first element from the collection."""
elements = ops.get_collection(collection_name)
if elements is not None:
if elements:
return elements[0]
return None
def _get_saver():
"""Lazy init and return saver."""
saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
if saver is not None:
if saver:
saver = saver[0]
else:
saver = None
if saver is None and variables.all_variables():
saver = tf_saver.Saver()
ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
return saver
def _export_graph(graph, saver, checkpoint_path, export_dir,
default_graph_signature, named_graph_signatures,
exports_to_keep):
"""Exports graph via session_bundle, by creating a Session."""
with graph.as_default():
with tf_session.Session('') as session:
session.run(variables.initialize_local_variables())
saver.restore(session, checkpoint_path)
export = exporter.Exporter(saver)
export.init(session.graph.as_graph_def(),
default_graph_signature=default_graph_signature,
named_graph_signatures=named_graph_signatures)
export.export(export_dir, contrib_variables.get_global_step(), session,
exports_to_keep=exports_to_keep)
def _generic_signature_fn(examples, unused_features, predictions):
"""Creates generic signature from given examples and predictions.
This is needed for backward compatibility with the default behavior of
export_estimator.
Args:
examples: `Tensor`.
unused_features: `dict` of `Tensor`s.
predictions: `dict` of `Tensor`s.
Returns:
Tuple of default signature and named signature.
"""
tensors = {'inputs': examples}
if not isinstance(predictions, dict):
predictions = {'outputs': predictions}
tensors.update(predictions)
default_signature = exporter.generic_signature(tensors)
return default_signature, {}
# pylint: disable=protected-access
def _default_input_fn(estimator, examples):
"""Creates default input parsing using Estimator's feature signatures."""
return estimator._get_feature_ops_from_example(examples)
def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
signature_fn=_generic_signature_fn, default_batch_size=1,
exports_to_keep=None):
"""Exports inference graph into given dir.
Args:
estimator: Estimator to export.
export_dir: A string containing a directory to write the exported graph
and checkpoints.
input_fn: Function that, given a `Tensor` of `Example` strings, parses it into
features that are then passed to the model.
signature_fn: Function that, given a `Tensor` of `Example` strings, a
`dict` of `Tensor`s for features, and a `dict` of `Tensor`s for predictions,
returns the default and named exporting signatures.
default_batch_size: Default batch size of the `Example` placeholder.
exports_to_keep: Number of exports to keep.
"""
checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
with ops.Graph().as_default() as g:
contrib_variables.create_global_step(g)
examples = array_ops.placeholder(dtype=dtypes.string,
shape=[default_batch_size],
name='input_example_tensor')
features = input_fn(estimator, examples)
predictions = estimator._get_predict_ops(features)
default_signature, named_graph_signatures = signature_fn(
examples, features, predictions)
if exports_to_keep is not None:
exports_to_keep = gc.largest_export_versions(exports_to_keep)
_export_graph(g, _get_saver(), checkpoint_path, export_dir,
default_graph_signature=default_signature,
named_graph_signatures=named_graph_signatures,
exports_to_keep=exports_to_keep)
# pylint: enable=protected-access


@@ -0,0 +1,50 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for export tools."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow.contrib import learn
class ExportTest(tf.test.TestCase):
def testExportMonitor(self):
random.seed(42)
x = np.random.rand(1000)
y = 2 * x + 3
regressor = learn.LinearRegressor()
export_dir = tempfile.mkdtemp() + 'export/'
export_monitor = learn.monitors.ExportMonitor(every_n_steps=1,
export_dir=export_dir,
exports_to_keep=1)
regressor.fit(x, y, steps=10,
monitors=[export_monitor])
self.assertTrue(tf.gfile.Exists(export_dir))
self.assertFalse(tf.gfile.Exists(export_dir + '00000000/export'))
self.assertTrue(tf.gfile.Exists(export_dir + '00000010/export'))
if __name__ == '__main__':
tf.test.main()


@@ -50,7 +50,10 @@ py_library(
],
data = [":python/ops/_sdca_ops.so"],
srcs_version = "PY2AND3",
deps = [":sdca_ops"],
deps = [
":sdca_ops",
"//tensorflow/contrib/lookup:lookup_py",
],
)
py_test(


@@ -38,40 +38,15 @@ cc_test(
],
)
cc_library(
name = "resources",
srcs = ["resources.cc"],
hdrs = ["resources.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@farmhash_archive//:farmhash",
"@protobuf//:protobuf",
],
)
cc_test(
name = "resources_test",
size = "small",
srcs = ["resources_test.cc"],
deps = [
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "sdca_ops",
srcs = ["sdca_ops.cc"],
deps = [
":loss_updaters",
":resources",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core/kernels:bounds_check_lib",
"//third_party/eigen3",
"@farmhash_archive//:farmhash",
"@protobuf//:protobuf",
],
alwayslink = 1,


@@ -1,86 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
DataByExample::DataByExample(const string& container, const string& solver_uuid)
: container_(container), solver_uuid_(solver_uuid) {}
DataByExample::~DataByExample() {}
// static
DataByExample::EphemeralKey DataByExample::MakeKey(const string& example_id) {
return Fingerprint128(example_id);
}
DataByExample::Data DataByExample::Get(const EphemeralKey& key) {
mutex_lock l(mu_);
return data_by_key_[key];
}
void DataByExample::Set(const EphemeralKey& key, const Data& data) {
mutex_lock l(mu_);
data_by_key_[key] = data;
}
Status DataByExample::Visit(
std::function<void(const Data& data)> visitor) const {
struct State {
// Snapshotted size of data_by_key_.
size_t size;
// Number of elements visited so far.
size_t num_visited = 0;
// Current element.
DataByKey::const_iterator it;
};
auto state = [this] {
mutex_lock l(mu_);
State result;
result.size = data_by_key_.size();
result.it = data_by_key_.cbegin();
return result;
}();
while (state.num_visited < state.size) {
mutex_lock l(mu_);
// Since DataByExample is modify-or-append only, a visit will (continue to)
// be successful if and only if the size of the backing store hasn't
// changed (since the body of this while-loop is under lock).
if (data_by_key_.size() != state.size) {
return errors::Unavailable("The number of elements for ", solver_uuid_,
" has changed which nullifies a visit.");
}
for (size_t i = 0; i < kVisitChunkSize && state.num_visited < state.size;
++i, ++state.num_visited, ++state.it) {
visitor(state.it->second);
}
}
return Status::OK();
}
string DataByExample::DebugString() {
return strings::StrCat("DataByExample(", container_, ", ", solver_uuid_, ")");
}
} // namespace tensorflow

View File

@ -1,108 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Resource for storing per-example data across many sessions. The data is
// operated on in a modify or append fashion (data can be modified or added, but
// never deleted).
//
// This class is thread-safe.
class DataByExample : public ResourceBase {
public:
// The container and solver_uuid are only used for debugging purposes.
DataByExample(const string& container, const string& solver_uuid);
virtual ~DataByExample();
// Platform independent, compact and unique (with very high probability)
// representation of an example id. 'Ephemeral' because it shouldn't be put
// in persistent storage, as its implementation may change in the future.
//
// The current probability of at least one collision for 1B example_ids is
// approximately 10^-21 (ie 2^60 / 2^129).
using EphemeralKey = Fprint128;
// Makes a key for the supplied example_id, for compact storage.
static EphemeralKey MakeKey(const string& example_id);
struct Data {
float dual = 0;
float primal_loss = 0;
float dual_loss = 0;
float example_weight = 0;
};
// Accessor and mutator for the entry at Key. Accessor creates an entry with
// default value (default constructed object) if the key is not present and
// returns it.
Data Get(const EphemeralKey& key) LOCKS_EXCLUDED(mu_);
void Set(const EphemeralKey& key, const Data& data) LOCKS_EXCLUDED(mu_);
// Visits all elements in this resource. The view of each element (Data) is
// atomic, but the entirety of the visit is not (ie the visitor might see
// different versions of the Data across elements).
//
// Returns OK on success or UNAVAILABLE if the number of elements in this
// container has changed since the beginning of the visit (in which case the
// visit cannot be completed and is aborted early, and computation can be
// restarted).
Status Visit(std::function<void(const Data& data)> visitor) const
LOCKS_EXCLUDED(mu_);
string DebugString() override;
private:
// Backing container.
//
// sizeof(EntryPayload) =
// sizeof(Key) + sizeof(Data) =
// 16 + 16 = 32.
//
// So on average we use ~51.5 (32 + 19.5) bytes per entry in this table.
using EphemeralKeyHasher = Fprint128Hasher;
using DataByKey = std::unordered_map<EphemeralKey, Data, EphemeralKeyHasher>;
// TODO(sibyl-Mooth6ku): Benchmark and/or optimize this.
static const size_t kVisitChunkSize = 100;
const string container_;
const string solver_uuid_;
// TODO(sibyl-Mooth6ku): Come up with a more efficient locking scheme.
mutable mutex mu_;
DataByKey data_by_key_ GUARDED_BY(mu_);
friend class DataByExampleTest;
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_

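The chunked Visit protocol implemented in resources.cc above — snapshot the size, then take the lock once per chunk and abort if the container grew — is worth seeing outside the TensorFlow plumbing. A minimal standalone C++ sketch of the same idea follows (hypothetical names, std::mutex in place of tensorflow::mutex; the size check keeps the saved iterator safe for the same modify-or-append reason as in DataByExample):

#include <cstddef>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>

class ChunkedStore {
 public:
  void Set(const std::string& key, int value) {
    std::lock_guard<std::mutex> l(mu_);
    data_[key] = value;
  }

  // Visits every value, kChunk entries per lock acquisition. Returns false
  // -- the analogue of UNAVAILABLE above -- if the store grew mid-visit,
  // which in a modify-or-append store is the only way the saved iterator
  // could have been invalidated.
  bool Visit(const std::function<void(int)>& visitor) const {
    size_t size = 0;
    size_t num_visited = 0;
    std::unordered_map<std::string, int>::const_iterator it;
    {
      std::lock_guard<std::mutex> l(mu_);
      size = data_.size();
      it = data_.cbegin();
    }
    while (num_visited < size) {
      std::lock_guard<std::mutex> l(mu_);
      if (data_.size() != size) return false;
      for (size_t i = 0; i < kChunk && num_visited < size;
           ++i, ++num_visited, ++it) {
        visitor(it->second);
      }
    }
    return true;
  }

 private:
  static constexpr size_t kChunk = 100;
  mutable std::mutex mu_;
  std::unordered_map<std::string, int> data_;
};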
View File

@ -1,184 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
// Operators for testing convenience (for EQ and NE GUnit macros).
bool operator==(const DataByExample::Data& lhs,
const DataByExample::Data& rhs) {
return lhs.dual == rhs.dual && //
lhs.primal_loss == rhs.primal_loss && //
lhs.dual_loss == rhs.dual_loss && //
lhs.example_weight == rhs.example_weight;
}
bool operator!=(const DataByExample::Data& lhs,
const DataByExample::Data& rhs) {
return !(lhs == rhs);
}
class DataByExampleTest : public ::testing::Test {
protected:
void SetUp() override {
const string solver_uuid = "TheSolver";
ASSERT_TRUE(resource_manager_
.LookupOrCreate<DataByExample>(
container_, solver_uuid, &data_by_example_,
[&, this](DataByExample** ret) {
*ret = new DataByExample(container_, solver_uuid);
return Status::OK();
})
.ok());
}
void TearDown() override {
data_by_example_->Unref();
ASSERT_TRUE(resource_manager_.Cleanup(container_).ok());
}
// Accessors and mutators to private members of DataByExample for better
// testing.
static size_t VisitChunkSize() { return DataByExample::kVisitChunkSize; }
void InsertReservedEntryUnlocked() NO_THREAD_SAFETY_ANALYSIS {
data_by_example_->data_by_key_[{0, 0}];
}
const string container_ = "TheContainer";
ResourceMgr resource_manager_;
DataByExample* data_by_example_ = nullptr;
};
TEST_F(DataByExampleTest, MakeKeyIsCollisionResistent) {
const DataByExample::EphemeralKey key =
DataByExample::MakeKey("TheExampleId");
EXPECT_NE(key.low64, key.high64);
}
TEST_F(DataByExampleTest, MakeKeyIsPlatformAgnostic) {
// This is one way of enforcing the platform-agnostic nature of
// DataByExample::MakeKey. Basically we are checking against exact values and
// this test could be running across different platforms.
// Note that it is fine for expected values to change in the future, if the
// implementation of MakeKey changes (ie this is *not* a frozen test).
const DataByExample::EphemeralKey key =
DataByExample::MakeKey("TheExampleId");
EXPECT_EQ(10492632643343118393ULL, key.low64);
EXPECT_EQ(1007244271654873956ULL, key.high64);
}
TEST_F(DataByExampleTest, ElementAccessAndMutation) {
const DataByExample::EphemeralKey key1 =
DataByExample::MakeKey("TheExampleId1");
EXPECT_EQ(DataByExample::Data(), data_by_example_->Get(key1));
DataByExample::Data data1;
data1.dual = 1.0f;
data_by_example_->Set(key1, data1);
EXPECT_EQ(data1, data_by_example_->Get(key1));
const DataByExample::EphemeralKey key2 =
DataByExample::MakeKey("TheExampleId2");
EXPECT_NE(data_by_example_->Get(key1), data_by_example_->Get(key2));
}
TEST_F(DataByExampleTest, VisitEmpty) {
size_t num_elements = 0;
ASSERT_TRUE(
data_by_example_
->Visit([&](const DataByExample::Data& data) { ++num_elements; })
.ok());
EXPECT_EQ(0, num_elements);
}
TEST_F(DataByExampleTest, VisitMany) {
const size_t kNumElements = 2 * VisitChunkSize() + 1;
for (size_t i = 0; i < kNumElements; ++i) {
DataByExample::Data data;
data.dual = static_cast<float>(i);
data_by_example_->Set(DataByExample::MakeKey(strings::StrCat(i)), data);
}
size_t num_elements = 0;
double total_dual = 0;
ASSERT_TRUE(data_by_example_
->Visit([&](const DataByExample::Data& data) {
++num_elements;
total_dual += data.dual;
})
.ok());
EXPECT_EQ(kNumElements, num_elements);
EXPECT_DOUBLE_EQ(
// 0 + 1 + ... + (N-1) = (N-1)*N/2
(kNumElements - 1) * kNumElements / 2.0, total_dual);
}
TEST_F(DataByExampleTest, VisitUnavailable) {
// Populate enough entries so that Visiting will be chunked.
for (size_t i = 0; i < 2 * VisitChunkSize(); ++i) {
data_by_example_->Get(DataByExample::MakeKey(strings::StrCat(i)));
}
struct Condition {
mutex mu;
bool c GUARDED_BY(mu) = false;
condition_variable cv;
};
auto signal = [](Condition* const condition) {
mutex_lock l(condition->mu);
condition->c = true;
condition->cv.notify_all();
};
auto wait = [](Condition* const condition) {
mutex_lock l(condition->mu);
while (!condition->c) {
condition->cv.wait(l);
}
};
Condition paused_visit; // Signaled after a Visit has paused.
Condition updated_data; // Signaled after data has been updated.
Condition completed_visit; // Signaled after a Visit has completed.
thread::ThreadPool thread_pool(Env::Default(), "test", 2 /* num_threads */);
Status status;
size_t num_visited = 0;
thread_pool.Schedule([&] {
status = data_by_example_->Visit([&](const DataByExample::Data& unused) {
++num_visited;
if (num_visited == VisitChunkSize()) {
// Safe point to mutate the data structure without a lock below.
signal(&paused_visit);
wait(&updated_data);
}
});
signal(&completed_visit);
});
thread_pool.Schedule([&, this] {
wait(&paused_visit);
InsertReservedEntryUnlocked();
signal(&updated_data);
});
wait(&completed_visit);
EXPECT_TRUE(errors::IsUnavailable(status));
}
} // namespace tensorflow

View File

@ -28,14 +28,12 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h"
#include "tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h"
#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
#include "tensorflow/contrib/linear_optimizer/kernels/squared-loss.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
@ -47,6 +45,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/group_iterator.h"
@ -596,22 +595,18 @@ class FeaturesAndWeights {
};
Status RunTrainStepsForMiniBatch(
const int num_examples, const TTypes<const string>::Vec example_ids,
const TTypes<const float>::Vec example_labels,
const int num_examples, const TTypes<const float>::Vec example_labels,
const TTypes<const float>::Vec example_weights,
const DeviceBase::CpuWorkerThreads& worker_threads,
const Regularizations& regularizations, const DualLossUpdater& loss_updater,
FeaturesAndWeights* const features_and_weights,
DataByExample* const data_by_example) {
TTypes<float>::Matrix example_state_data) {
// Process examples in parallel, in a partitioned fashion.
mutex mu;
Status train_step_status GUARDED_BY(mu);
auto train_step = [&](const int64 begin, const int64 end) {
for (int64 example_index = begin; example_index < end; ++example_index) {
// Get example id, label, and weight.
const DataByExample::EphemeralKey example_key =
DataByExample::MakeKey(example_ids(example_index));
DataByExample::Data data = data_by_example->Get(example_key);
const float dual = example_state_data(example_index, 0);
const float example_weight = example_weights(example_index);
float example_label = example_labels(example_index);
const Status conversion_status =
@ -633,24 +628,23 @@ Status RunTrainStepsForMiniBatch(
const double primal_loss = loss_updater.ComputePrimalLoss(
per_example_data.wx, example_label, example_weight);
const double dual_loss = loss_updater.ComputeDualLoss(
data.dual, example_label, example_weight);
const double dual_loss =
loss_updater.ComputeDualLoss(dual, example_label, example_weight);
const double new_dual = loss_updater.ComputeUpdatedDual(
example_label, example_weight, data.dual, per_example_data.wx,
example_label, example_weight, dual, per_example_data.wx,
per_example_data.normalized_squared_norm, primal_loss, dual_loss);
// Compute new weights.
const double bounded_dual_delta = (new_dual - data.dual) * example_weight;
const double bounded_dual_delta = (new_dual - dual) * example_weight;
features_and_weights->UpdateDeltaWeights(
example_index, bounded_dual_delta, regularizations.symmetric_l2());
// Update example data.
data.dual = new_dual;
data.primal_loss = primal_loss;
data.dual_loss = dual_loss;
data.example_weight = example_weight;
data_by_example->Set(example_key, data);
example_state_data(example_index, 0) = new_dual;
example_state_data(example_index, 1) = primal_loss;
example_state_data(example_index, 2) = dual_loss;
example_state_data(example_index, 3) = example_weight;
}
};
// TODO(sibyl-Aix6ihai): Current multiplier 100000 works well empirically
@ -689,30 +683,9 @@ class SdcaSolver : public OpKernel {
OP_REQUIRES_OK(context, regularizations_.Initialize(context));
OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
&num_inner_iterations_));
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
}
void Compute(OpKernelContext* context) override {
// Get a handle on a shared container across invocations of this Kernel.
// The shared container is intended to maintain state at the example level
// across invocations of the kernel on different input data.
//
// TODO(sibyl-Mooth6ku): Replace this in-Kernel data structure with a first class
// citizen mutable Dictionary in tensorflow proper, that we will initialize
// and update externally.
DataByExample* data_by_example = nullptr;
OP_REQUIRES_OK(context,
context->resource_manager()->LookupOrCreate<DataByExample>(
container_, solver_uuid_, &data_by_example,
[this](DataByExample** ret) {
*ret = new DataByExample(container_, solver_uuid_);
return Status::OK();
}));
OP_REQUIRES(
context, !data_by_example->RefCountIsOne(),
errors::Internal("Expected shared-ownership of data_by_example."));
const Tensor* example_weights_t;
OP_REQUIRES_OK(context,
context->input("example_weights", &example_weights_t));
@ -738,16 +711,19 @@ class SdcaSolver : public OpKernel {
"number of example weights (%d).",
example_labels.size(), num_examples)));
const Tensor* example_ids_t;
OP_REQUIRES_OK(context, context->input("example_ids", &example_ids_t));
OP_REQUIRES(context, TensorShapeUtils::IsVector(example_ids_t->shape()),
errors::InvalidArgument("example_ids should be a vector."));
const auto example_ids = example_ids_t->vec<string>();
OP_REQUIRES(context, example_labels.size() == num_examples,
errors::InvalidArgument(strings::Printf(
"The number of example ids (%ld) should match the number "
"of example weights (%d).",
example_ids.size(), num_examples)));
const Tensor* example_state_data_t;
OP_REQUIRES_OK(context,
context->input("example_state_data", &example_state_data_t));
TensorShape expected_example_state_shape({num_examples, 4});
OP_REQUIRES(
context, example_state_data_t->shape() == expected_example_state_shape,
errors::InvalidArgument("Expected shape ",
expected_example_state_shape.DebugString(),
" for example_state_data, got ",
example_state_data_t->shape().DebugString()));
Tensor mutable_example_state_data_t(*example_state_data_t);
auto example_state_data = mutable_example_state_data_t.matrix<float>();
FeaturesAndWeights features_and_weights;
OP_REQUIRES_OK(context,
@ -757,17 +733,15 @@ class SdcaSolver : public OpKernel {
for (int i = 0; i < num_inner_iterations_; ++i) {
OP_REQUIRES_OK(
context,
RunTrainStepsForMiniBatch(
num_examples, example_ids, example_labels, example_weights,
*context->device()->tensorflow_cpu_worker_threads(),
regularizations_, *loss_updater_, &features_and_weights,
data_by_example));
context, RunTrainStepsForMiniBatch(
num_examples, example_labels, example_weights,
*context->device()->tensorflow_cpu_worker_threads(),
regularizations_, *loss_updater_, &features_and_weights,
example_state_data));
}
features_and_weights.AddDeltaWeights();
// TODO(sibyl-Mooth6ku): Use core::ScopedUnref once it's moved out of internal.
data_by_example->Unref();
context->set_output(0, mutable_example_state_data_t);
}
private:
@ -779,8 +753,6 @@ class SdcaSolver : public OpKernel {
int64 num_dense_features_;
Regularizations regularizations_;
int num_inner_iterations_;
string container_;
string solver_uuid_;
};
REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver);
@ -803,72 +775,26 @@ class SdcaShrinkL1 : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
class SdcaTrainingStats : public OpKernel {
class SdcaFprint : public OpKernel {
public:
explicit SdcaTrainingStats(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
explicit SdcaFprint(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
Tensor* out;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
const auto in_values = input.flat<string>();
auto out_values = out->flat<string>();
for (int64 i = 0; i < in_values.size(); ++i) {
const string& s = in_values(i);
Fprint128 fprint = Fingerprint128(s);
out_values(i) = strings::StrCat(strings::FpToString(fprint.high64), "-",
strings::FpToString(fprint.low64));
}
}
void Compute(OpKernelContext* context) override {
DataByExample* data_by_example = nullptr;
OP_REQUIRES_OK(context, context->resource_manager()->Lookup<DataByExample>(
container_, solver_uuid_, &data_by_example));
OP_REQUIRES(
context, !data_by_example->RefCountIsOne(),
errors::Internal("Expected shared-ownership of data_by_example."));
double total_primal_loss = 0;
double total_dual_loss = 0;
double total_example_weight = 0;
OP_REQUIRES_OK(context,
data_by_example->Visit([&](const DataByExample::Data& data) {
total_primal_loss += data.primal_loss;
total_dual_loss += data.dual_loss;
total_example_weight += data.example_weight;
}));
// TODO(sibyl-Mooth6ku): Think about the most arithmetically stable way of
// computing (dual + primal) loss (if it matters).
{
Tensor* tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("primal_loss", {}, &tensor));
tensor->scalar<double>()() = total_primal_loss;
}
{
Tensor* tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("dual_loss", {}, &tensor));
tensor->scalar<double>()() = total_dual_loss;
}
{
OP_REQUIRES(
context, total_example_weight > 0,
errors::FailedPrecondition(
"No examples found or all examples have zero weight. Either the "
"optimizer was trained with no instances or perhaps there is a "
"bug in the training data."));
Tensor* tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("example_weights", {}, &tensor));
tensor->scalar<double>()() = total_example_weight;
}
// TODO(sibyl-Mooth6ku): Use core::ScopedUnref once it's moved out of internal.
data_by_example->Unref();
}
private:
string container_;
string solver_uuid_;
};
REGISTER_KERNEL_BUILDER(Name("SdcaTrainingStats").Device(DEVICE_CPU),
SdcaTrainingStats);
REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
} // namespace tensorflow

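The SdcaFprint kernel above is small enough that its whole contract fits in one helper. A sketch of the per-string mapping, assuming the same TensorFlow helpers the kernel itself uses (Fingerprint128, strings::FpToString, strings::StrCat; the include set here is abbreviated to headers visible elsewhere in this diff):

#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/fingerprint.h"

namespace tensorflow {

// Maps an example id to the "high64-low64" form emitted by SdcaFprint; the
// Python test later in this commit expects e.g.
// "a085f09013029e45-3980b2afd2126c04" for the input "abc".
string FprintString(const string& s) {
  const Fprint128 fprint = Fingerprint128(s);
  return strings::StrCat(strings::FpToString(fprint.high64), "-",
                         strings::FpToString(fprint.low64));
}

}  // namespace tensorflow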
View File

@ -25,16 +25,15 @@ REGISTER_OP("SdcaSolver")
.Attr("l1: float")
.Attr("l2: float")
.Attr("num_inner_iterations: int >= 1")
.Attr("container: string")
.Attr("solver_uuid: string")
.Input("sparse_features_indices: num_sparse_features * int64")
.Input("sparse_features_values: num_sparse_features * float")
.Input("dense_features: num_dense_features * float")
.Input("example_weights: float")
.Input("example_labels: float")
.Input("example_ids: string")
.Input("sparse_weights: Ref(num_sparse_features * float)")
.Input("dense_weights: Ref(num_dense_features * float)")
.Input("example_state_data: float")
.Output("example_data_data_out: float")
.Doc(R"doc(
Stochastic Dual Coordinate Ascent (SDCA) optimizer for linear models with
L1 + L2 regularization. As the global optimization objective is strongly-convex, the
@ -54,9 +53,6 @@ num_dense_features: Number of dense feature groups to train on.
l1: Symmetric l1 regularization strength.
l2: Symmetric l2 regularization strength.
num_inner_iterations: Number of iterations per mini-batch.
container: Name of the Container that stores data across invocations of this
Kernel. Together with SolverUUID, it forms an isolation unit for this solver.
solver_uuid: Universally Unique Identifier for this solver.
sparse_features_indices: a list of matrices with two columns that contain
example_indices, and feature_indices.
sparse_features_values: a list of vectors which contains feature value
@ -66,12 +62,13 @@ example_weights: a vector which contains the weight associated with each
example.
example_labels: a vector which contains the label/target associated with each
example.
example_ids: a vector which contains the unique identifier associated with each
example.
sparse_weights: a list of vectors where each value is the weight associated with
a feature group.
dense_weights: a list of vectors where the value is the weight associated with
a dense feature group.
example_state_data: a list of vectors containing the example state data.
example_data_data_out: a list of vectors containing the updated example state
data.
)doc");
REGISTER_OP("SdcaShrinkL1")
@ -94,23 +91,14 @@ dense_weights: a list of vectors where the value is the weight associated with
a dense feature group.
)doc");
REGISTER_OP("SdcaTrainingStats")
.Attr("container: string")
.Attr("solver_uuid: string")
.Output("primal_loss: float64")
.Output("dual_loss: float64")
.Output("example_weights: float64")
REGISTER_OP("SdcaFprint")
.Input("input: string")
.Output("output: string")
.Doc(R"doc(
Computes statistics over all examples seen by the optimizer.
Computes fingerprints of the input strings.
container: Name of the Container that stores data across invocations of this
Kernel. Together with SolverUUID, it forms an isolation unit for this solver.
solver_uuid: Universally Unique Identifier for this solver.
primal_loss: total primal loss of all examples seen by the optimizer.
dual_loss: total dual loss of all examples seen by the optimizer.
example_weights: total example weights of all examples seen by the optimizer
(guaranteed to be positive; otherwise returns FAILED_PRECONDITION as it
probably indicates a bug in the training data).
input: strings to compute fingerprints on.
output: the computed fingerprints.
)doc");
} // namespace tensorflow

View File

@ -24,6 +24,7 @@ import uuid
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _sdca_ops
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.platform import googletest
@ -138,37 +139,6 @@ class SdcaOptimizerTest(TensorFlowTestCase):
intra_op_parallelism_threads=1)
return self.test_session(use_gpu=False, config=config)
# The following tests check that operations raise errors when certain
# preconditions on the input data are not satisfied. These errors are raised
# regardless of the loss type.
def testNoWeightedExamples(self):
# Setup test data with 1 positive, and 1 negative example.
example_protos = [
make_example_proto(
{'age': [0],
'gender': [0]}, 0),
make_example_proto(
{'age': [1],
'gender': [1]}, 1),
]
# Zeroed out example weights.
example_weights = [0.0, 0.0]
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
lr.minimize().run()
self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
with self.assertRaisesOpError(
'No examples found or all examples have zero weight.'):
lr.approximate_duality_gap().eval()
class SdcaWithLogisticLossTest(SdcaOptimizerTest):
"""SDCA optimizer test class for logistic loss."""
@ -815,5 +785,18 @@ class SdcaWithHingeLossTest(SdcaOptimizerTest):
self.assertAllClose(0.2, unregularized_loss.eval(), atol=0.02)
self.assertAllClose(0.4, regularized_loss.eval(), atol=0.02)
class SdcaFprintTest(TensorFlowTestCase):
"""Tests for the SdcaFprint op."""
def testFprint(self):
with self.test_session():
in_data = tf.constant(['abc', 'very looooooong string', 'def'])
out_data = _sdca_ops.sdca_fprint(in_data)
self.assertAllEqual([b'a085f09013029e45-3980b2afd2126c04',
b'bc5a254df959f26c-512e479a50910f9f',
b'79999cd817a03f12-085f182230e03022'],
out_data.eval())
if __name__ == '__main__':
googletest.main()

View File

@ -17,11 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import uuid
from six.moves import range # pylint: disable=redefined-builtin
from tensorflow.contrib.linear_optimizer.ops import gen_sdca_ops
from tensorflow.contrib.lookup import lookup_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework.load_library import load_op_library
@ -106,10 +105,11 @@ class SdcaModel(object):
```
"""
def __init__(self, container, examples, variables, options):
def __init__(self, container, examples, variables, options): # pylint: disable=unused-argument
"""Create a new sdca optimizer."""
# TODO(andreasst): get rid of obsolete container parameter
if not container or not examples or not variables or not options:
if not examples or not variables or not options:
raise ValueError('All arguments must be specified.')
supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss')
@ -136,12 +136,12 @@ class SdcaModel(object):
raise ValueError('%s should be non-negative. Found (%f)' %
(name, value))
self._container = container
self._examples = examples
self._variables = variables
self._options = options
self._solver_uuid = uuid.uuid4().hex
self._create_slots()
self._hashtable = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32,
[0.0, 0.0, 0.0, 0.0])
def _symmetric_l2_regularization(self):
# Algorithmic requirement (for now) is to have minimal l2 of 1.0
@ -264,19 +264,23 @@ class SdcaModel(object):
sparse_features_indices.append(convert_to_tensor(sf.indices))
sparse_features_values.append(convert_to_tensor(sf.values))
step_op = _sdca_ops.sdca_solver(
example_ids_hashed = _sdca_ops.sdca_fprint(convert_to_tensor(
self._examples['example_ids']))
example_state_data = self._hashtable.lookup(example_ids_hashed)
example_state_data_updated = _sdca_ops.sdca_solver(
sparse_features_indices,
sparse_features_values,
self._convert_n_to_tensor(self._examples['dense_features']),
convert_to_tensor(self._examples['example_weights']),
convert_to_tensor(self._examples['example_labels']),
convert_to_tensor(self._examples['example_ids']),
self._convert_n_to_tensor(
self._slots['unshrinked_sparse_features_weights'],
as_ref=True),
self._convert_n_to_tensor(
self._slots['unshrinked_dense_features_weights'],
as_ref=True),
example_state_data,
l1=self._options['symmetric_l1_regularization'],
l2=self._symmetric_l2_regularization(),
# TODO(sibyl-Aix6ihai): Provide empirical evidence for this. It is better
@ -286,17 +290,17 @@ class SdcaModel(object):
# reuse old samples than train on new samples.
# See: http://arxiv.org/abs/1602.02136.
num_inner_iterations=2,
loss_type=self._options['loss_type'],
container=self._container,
solver_uuid=self._solver_uuid)
with ops.control_dependencies([step_op]):
assign_ops = []
loss_type=self._options['loss_type'])
with ops.control_dependencies([example_state_data_updated]):
insert_op = self._hashtable.insert(example_ids_hashed,
example_state_data_updated)
update_ops = [insert_op]
for name in ['sparse_features_weights', 'dense_features_weights']:
for var, slot_var in zip(self._variables[name],
self._slots['unshrinked_' + name]):
assign_ops.append(var.assign(slot_var))
assign_group = control_flow_ops.group(*assign_ops)
with ops.control_dependencies([assign_group]):
update_ops.append(var.assign(slot_var))
update_group = control_flow_ops.group(*update_ops)
with ops.control_dependencies([update_group]):
shrink_l1 = _sdca_ops.sdca_shrink_l1(
self._convert_n_to_tensor(
self._variables['sparse_features_weights'],
@ -318,14 +322,17 @@ class SdcaModel(object):
An Operation that computes the approximate duality gap over all
examples.
"""
(primal_loss, dual_loss, example_weights) = _sdca_ops.sdca_training_stats(
container=self._container,
solver_uuid=self._solver_uuid)
# Note that example_weights is guaranteed to be positive by
# sdca_training_stats so dividing by it is safe.
return (primal_loss + dual_loss + math_ops.to_double(self._l1_loss()) +
(2.0 * math_ops.to_double(self._l2_loss(
self._symmetric_l2_regularization())))) / example_weights
_, exported_values = self._hashtable.export()
summed_values = math_ops.reduce_sum(exported_values, 0)
primal_loss = summed_values[1]
dual_loss = summed_values[2]
example_weights = summed_values[3]
# TODO(andreasst): what about handling example_weights == 0?
return (
primal_loss + dual_loss + math_ops.to_float(self._l1_loss()) +
(2.0 *
math_ops.to_float(self._l2_loss(self._symmetric_l2_regularization())))
) / example_weights
def unregularized_loss(self, examples):
"""Add operations to compute the loss (without the regularization loss).

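The new minimize() flow above replaces the in-kernel container with an explicit cycle: fingerprint the example ids, look up one four-float state row per example from the MutableHashTable, hand the rows to the solver, and insert the updated rows back under a control dependency. Reduced to a standalone C++ sketch (every name here is hypothetical; std::hash stands in for SdcaFprint and a plain map for the hash table):

#include <array>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// One row per example, mirroring the {num_examples, 4} example_state_data
// tensor: {dual, primal_loss, dual_loss, example_weight}.
using ExampleState = std::array<float, 4>;

// 'solver' plays the role of the SdcaSolver kernel: old state row in,
// updated state row out.
void MiniBatchStep(const std::vector<std::string>& example_ids,
                   ExampleState (*solver)(const ExampleState&),
                   std::unordered_map<uint64_t, ExampleState>* table) {
  for (const std::string& id : example_ids) {
    const uint64_t key = std::hash<std::string>{}(id);  // stand-in for fprint
    // operator[] value-initializes a zero row for unseen ids, like the
    // hash table's [0.0, 0.0, 0.0, 0.0] default value above.
    (*table)[key] = solver((*table)[key]);
  }
}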
View File

@ -32,7 +32,7 @@ py_library(
deps = [
":gc",
":manifest_proto_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework",
],
)
@ -57,7 +57,7 @@ py_library(
srcs = ["gc.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework",
],
)

View File

@ -25,13 +25,15 @@ import os
import re
import six
import tensorflow as tf
from google.protobuf.any_pb2 import Any
from tensorflow.contrib.session_bundle import gc
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
@ -62,20 +64,20 @@ def gfile_copy_callback(files_to_copy, export_dir_path):
basename in the export directory.
export_dir_path: Directory to copy the files to.
"""
tf.logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
gfile.MakeDirs(export_dir_path)
for source_filepath, basename in files_to_copy.items():
new_path = os.path.join(
compat.as_bytes(export_dir_path), compat.as_bytes(basename))
tf.logging.info("Copying asset %s to path %s.", source_filepath, new_path)
logging.info("Copying asset %s to path %s.", source_filepath, new_path)
if gfile.Exists(new_path):
# Guard against being restarted while copying assets, and the file
# existing and being in an unknown state.
# TODO(b/28676216): Do some file checks before deleting.
tf.logging.info("Removing file %s.", new_path)
logging.info("Removing file %s.", new_path)
gfile.Remove(new_path)
tf.gfile.Copy(source_filepath, new_path)
gfile.Copy(source_filepath, new_path)
def regression_signature(input_tensor, output_tensor):
@ -188,22 +190,22 @@ class Exporter(object):
self._has_init = True
if graph_def or clear_devices:
copy = tf.GraphDef()
copy = graph_pb2.GraphDef()
if graph_def:
copy.CopyFrom(graph_def)
else:
copy.CopyFrom(tf.get_default_graph().as_graph_def())
copy.CopyFrom(ops.get_default_graph().as_graph_def())
if clear_devices:
for node in copy.node:
node.device = ""
graph_any_buf = Any()
graph_any_buf.Pack(copy)
tf.add_to_collection(GRAPH_KEY, graph_any_buf)
ops.add_to_collection(GRAPH_KEY, graph_any_buf)
if init_op:
if not isinstance(init_op, ops.Operation):
raise TypeError("init_op needs to be an Operation: %s" % init_op)
tf.add_to_collection(INIT_OP_KEY, init_op)
ops.add_to_collection(INIT_OP_KEY, init_op)
signatures_proto = manifest_pb2.Signatures()
if default_graph_signature:
@ -212,7 +214,7 @@ class Exporter(object):
signatures_proto.named_signatures[signature_name].CopyFrom(signature)
signatures_any_buf = Any()
signatures_any_buf.Pack(signatures_proto)
tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
ops.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
for filename, tensor in assets:
asset = manifest_pb2.AssetFile()
@ -220,7 +222,7 @@ class Exporter(object):
asset.tensor_binding.tensor_name = tensor.name
asset_any_buf = Any()
asset_any_buf.Pack(asset)
tf.add_to_collection(ASSETS_KEY, asset_any_buf)
ops.add_to_collection(ASSETS_KEY, asset_any_buf)
self._assets_callback = assets_callback
@ -250,6 +252,10 @@ class Exporter(object):
if not self._has_init:
raise RuntimeError("init must be called first")
# Export dir must not end with / or it will break exports to keep. Strip /.
if export_dir_base.endswith("/"):
export_dir_base = export_dir_base[:-1]
global_step = training_util.global_step(sess, global_step_tensor)
export_dir = os.path.join(
compat.as_bytes(export_dir_base),
@ -299,11 +305,11 @@ class Exporter(object):
def _file_path_value(self, path_tensor):
"""Returns the filepath value stored in constant `path_tensor`."""
if not isinstance(path_tensor, tf.Tensor):
if not isinstance(path_tensor, ops.Tensor):
raise TypeError("tensor is not a Tensor")
if path_tensor.op.type != "Const":
raise TypeError("Only constants tensor are supported")
if path_tensor.dtype != tf.string:
if path_tensor.dtype != dtypes.string:
raise TypeError("File paths should be string")
str_value = path_tensor.op.get_attr("value").string_val
if len(str_value) != 1:

View File

@ -1127,23 +1127,6 @@ cc_library(
alwayslink = 1,
)
tf_cuda_library(
name = "cupti_wrapper_default",
srcs = [
"platform/default/gpu/cupti_wrapper.cc",
],
hdrs = [
"platform/default/gpu/cupti_wrapper.h",
],
copts = tf_copts(),
cuda_deps = [
":stream_executor",
"//third_party/gpus/cuda:cuda_headers",
"//third_party/gpus/cuda:cupti_headers",
],
data = ["//third_party/gpus/cuda:cupti_dsos"],
)
tf_cuda_library(
name = "gpu_tracer",
srcs = [

View File

@ -184,27 +184,27 @@ class SimpleRendezvous : public Rendezvous {
public:
explicit SimpleRendezvous() {}
Status Send(const string& key, const Args& send_args, const Tensor& val,
Status Send(const ParsedKey& parsed, const Args& send_args, const Tensor& val,
const bool is_dead) override {
if (is_dead) {
return errors::Internal("Send of a dead tensor");
}
ParsedKey parsed;
TF_RETURN_IF_ERROR(ParseKey(key, &parsed));
mutex_lock l(mu_);
if (table_.count(parsed.edge_name) > 0) {
string edge_name = parsed.edge_name.ToString();
if (table_.count(edge_name) > 0) {
return errors::Internal("Send of an already sent tensor");
}
table_[parsed.edge_name] = val;
table_[edge_name] = val;
return Status::OK();
}
void RecvAsync(const string& key, const Args& recv_args,
void RecvAsync(const ParsedKey& parsed, const Args& recv_args,
DoneCallback done) override {
Tensor tensor;
Status status = Status::OK();
{
string key = parsed.edge_name.ToString();
mutex_lock l(mu_);
if (table_.count(key) <= 0) {
status = errors::Internal("Did not find key ", key);
@ -417,7 +417,14 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
// this node. Don't bother processing the rest of the nodes.
return c > 0;
}
Status s = rendez->Recv(tensor_name, Rendezvous::Args(), &output, &is_dead);
string full_key = Rendezvous::CreateKey("/cpu:0", 1, "/cpu:1", tensor_name,
FrameAndIter(0, 0));
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(full_key, &parsed);
if (s.ok()) {
s = rendez->Recv(parsed, Rendezvous::Args(), &output, &is_dead);
}
if (!s.ok() || is_dead) {
return c > 0;
}

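This CreateKey/ParseKey/Recv sequence is the pattern the commit threads through every Rendezvous call site: a key is parsed into a ParsedKey once, up front, and Send/Recv take the parsed form. In isolation, against the Rendezvous API shown in this diff (error handling compressed; the device and edge names are hypothetical):

#include "tensorflow/core/framework/rendezvous.h"

namespace tensorflow {

Status RoundTrip(Rendezvous* rendez, const Tensor& in, Tensor* out) {
  const string key = Rendezvous::CreateKey("/cpu:0", 1 /* src incarnation */,
                                           "/cpu:1", "edge_0",
                                           FrameAndIter(0, 0));
  Rendezvous::ParsedKey parsed;
  TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
  TF_RETURN_IF_ERROR(rendez->Send(parsed, Rendezvous::Args(), in,
                                  false /* is_dead */));
  bool is_dead = false;
  return rendez->Recv(parsed, Rendezvous::Args(), out, &is_dead);
}

}  // namespace tensorflow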
View File

@ -43,8 +43,7 @@ std::vector<RegistrationInfo>* MutableRegistry() {
} // namespace
// static
void CopyTensor::ViaDMA(const string& edge_name,
DeviceContext* send_dev_context,
void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
DeviceContext* recv_dev_context, Device* src,
Device* dst, const AllocatorAttributes src_alloc_attr,
const AllocatorAttributes dst_alloc_attr,

View File

@ -42,7 +42,7 @@ class CopyTensor {
// the type of devices and memory in use, the copy may be performed
// synchronously or asynchronously. 'done' will be invoked only
// after the copy is actually complete.
static void ViaDMA(const string& edge_name, DeviceContext* send_dev_context,
static void ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
DeviceContext* recv_dev_context, Device* src, Device* dst,
const AllocatorAttributes src_alloc_attr,
const AllocatorAttributes dst_alloc_attr,

View File

@ -24,13 +24,17 @@ limitations under the License.
namespace tensorflow {
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) {
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
: name_backing_store_(128) {
for (Device* d : devices) {
devices_.push_back(d);
// Register under both the full name and the local name.
device_map_[d->name()] = d;
device_map_[DeviceNameUtils::LocalName(d->name())] = d;
string full_name = d->name();
device_map_[CopyToBackingStore(full_name)] = d;
string lname = DeviceNameUtils::LocalName(d->name());
device_map_[CopyToBackingStore(lname)] = d;
device_type_counts_[d->device_type()]++;
}
}
@ -39,6 +43,13 @@ DeviceMgr::~DeviceMgr() {
for (auto p : devices_) delete p;
}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
int n = s.size();
char* space = name_backing_store_.Alloc(n);
memcpy(space, s.data(), n);
return StringPiece(space, n);
}
void DeviceMgr::ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const {
devices->reserve(devices_.size());
@ -70,7 +81,7 @@ string DeviceMgr::DeviceMappingString() const {
return out;
}
Status DeviceMgr::LookupDevice(const string& name, Device** device) const {
Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
Status s;
auto iter = device_map_.find(name);
if (iter == device_map_.end()) {

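The point of name_backing_store_ above is that an unordered_map keyed by StringPiece stores only pointers: every key must reference bytes that outlive the map, so each device name is copied exactly once into an arena, and lookups can then accept any StringPiece without building a std::string. The same pattern standalone (a hypothetical class; C++17 std::string_view in place of StringPiece and a std::deque in place of core::Arena):

#include <deque>
#include <string>
#include <string_view>
#include <unordered_map>

class ViewKeyedMap {
 public:
  void Insert(std::string_view name, int device_id) {
    map_[CopyToBackingStore(name)] = device_id;
  }

  // Accepts any string_view -- a substring, a literal, a StringPiece-like
  // slice -- without materializing a std::string key.
  bool Lookup(std::string_view name, int* device_id) const {
    const auto it = map_.find(name);
    if (it == map_.end()) return false;
    *device_id = it->second;
    return true;
  }

 private:
  // The analogue of DeviceMgr::CopyToBackingStore: copy the key bytes into
  // storage whose addresses never move. std::deque never relocates existing
  // elements, so the returned view stays valid for the map's lifetime.
  std::string_view CopyToBackingStore(std::string_view s) {
    backing_.emplace_back(s);
    return backing_.back();
  }

  std::deque<std::string> backing_;
  std::unordered_map<std::string_view, int> map_;
};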
View File

@ -22,7 +22,9 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"
@ -49,7 +51,7 @@ class DeviceMgr {
// Assigns *device with pointer to Device of the given name.
// Accepts either a full device name, or just the replica-local suffix.
Status LookupDevice(const string& name, Device** device) const;
Status LookupDevice(StringPiece name, Device** device) const;
// Clears given containers of all devices if 'container' is
// non-empty. Otherwise, clears default containers of all devices.
@ -60,7 +62,11 @@ class DeviceMgr {
private:
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
DeviceVec devices_;
std::unordered_map<string, Device*> device_map_;
StringPiece CopyToBackingStore(StringPiece s);
std::unordered_map<StringPiece, Device*, StringPiece::Hasher> device_map_;
core::Arena name_backing_store_; // Storage for keys in device_map_
std::unordered_map<string, int> device_type_counts_;
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);

View File

@ -551,6 +551,7 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
const ExecutorsAndKeys* executors_and_keys,
IntraProcessRendezvous* rendez) {
Status s;
Rendezvous::ParsedKey parsed;
// Insert the input tensors into the local rendezvous by their
// rendezvous key.
for (const auto& input : inputs) {
@ -560,7 +561,14 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
"' is not a pre-defined feed!");
}
const string& input_key = it->second;
s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
s = Rendezvous::ParseKey(input_key, &parsed);
if (!s.ok()) {
rendez->StartAbort(s);
return s;
}
s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
if (!s.ok()) {
rendez->StartAbort(s);
return s;
@ -578,6 +586,7 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
outputs->resize(output_names.size());
}
Rendezvous::ParsedKey parsed;
// Get the outputs from the rendezvous
for (size_t output_offset = 0; output_offset < output_names.size();
++output_offset) {
@ -591,14 +600,16 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
const string& output_key = it->second;
Tensor output_tensor;
bool is_dead;
// Fetch data from the Rendezvous.
IntraProcessRendezvous* rendez = run_state->rendez;
s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead);
if (is_dead && s.ok()) {
s = errors::InvalidArgument("The tensor returned for ",
output_names[output_offset],
" was not valid.");
s = Rendezvous::ParseKey(output_key, &parsed);
if (s.ok()) {
// Fetch data from the Rendezvous.
s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead);
if (is_dead && s.ok()) {
s = errors::InvalidArgument("The tensor returned for ", output_name,
" was not valid.");
}
}
if (!s.ok()) {
rendez->StartAbort(s);

View File

@ -27,6 +27,8 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
~EigenThreadPoolWrapper() override {}
void Schedule(std::function<void()> fn) override { pool_->Schedule(fn); }
int NumThreads() const override { return pool_->NumThreads(); }
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
private:
thread::ThreadPool* pool_ = nullptr;

View File

@ -645,6 +645,43 @@ class ExecutorState {
}
};
// A drop-in replacement for std::deque<TaggedNode>. We typically don't
// have that many nodes in the ready queue, so we just use a vector and
// don't free up memory from the queue as we consume nodes.
class TaggedNodeReadyQueue {
public:
TaggedNodeReadyQueue() : front_index_(0) {}
void push_back(TaggedNode node) { ready_.push_back(node); }
TaggedNode front() const {
DCHECK_LT(front_index_, ready_.size());
return ready_[front_index_];
}
void pop_front() {
DCHECK_LT(front_index_, ready_.size());
front_index_++;
if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
if (front_index_ == ready_.size()) {
ready_.clear();
} else {
// Lots of unused entries at beginning of vector: move everything down
// to start of vector.
ready_.erase(ready_.begin(), ready_.begin() + front_index_);
}
front_index_ = 0;
}
}
bool empty() const { return ready_.empty(); }
const TaggedNode* begin() const { return ready_.begin() + front_index_; }
const TaggedNode* end() const { return ready_.end(); }
private:
gtl::InlinedVector<TaggedNode, 16> ready_;
int front_index_;
};
struct AsyncState;
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
typedef gtl::InlinedVector<Entry, 4> EntryVector;
@ -767,15 +804,15 @@ class ExecutorState {
// "node" just finishes. Takes ownership of "stats". Returns true if
// execution has completed.
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
NodeExecStats* stats, std::deque<TaggedNode>* inline_ready);
NodeExecStats* stats, TaggedNodeReadyQueue* inline_ready);
// Call Process() on all nodes in 'inline_ready'.
void ProcessInline(const std::deque<TaggedNode>& inline_ready);
void ProcessInline(const TaggedNodeReadyQueue& inline_ready);
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
// nodes in 'ready' into 'inline_ready'.
void ScheduleReady(const TaggedNodeSeq& ready,
std::deque<TaggedNode>* inline_ready);
TaggedNodeReadyQueue* inline_ready);
// Provide debugging output about an outstanding node in the executor.
void DumpCompletedNodeState(const int node_id, const Entry* input_vector);
@ -905,43 +942,55 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
}
}
namespace {
// State kept alive for executing an asynchronous node in another
// thread. NOTE: We need to make a copy of p.input,
// p.input_device_contexts, and p.input_alloc_attrs for asynchronous
// kernels because OpKernelContext methods like input_type(i) needs
// the param points to valid input type vector. It's not an issue for
// sync kernels because these vectors are kept on the stack.
struct ExecutorState::AsyncState {
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
const NodeItem& _item, Entry* _first_input, NodeExecStats* _stats)
: saved_inputs(*p.inputs),
saved_input_device_contexts(*p.input_device_contexts),
saved_input_alloc_attrs(*p.input_alloc_attrs),
params(p),
tagged_node(_tagged_node),
item(_item),
first_input(_first_input),
// ParamsButClearingEigenGPUDevice does equivalent of
// params.eigen_gpu_device = nullptr;
ctx(ParamsButClearingEigenGPUDevice(&params), item.num_outputs),
stats(_stats) {
params.inputs = &saved_inputs;
params.input_device_contexts = &saved_input_device_contexts;
params.input_alloc_attrs = &saved_input_alloc_attrs;
}
// Helpers to make a copy of 'p' and makes a copy of the input type
// vector and the device context vector.
//
// NOTE: We need to make a copy of p.input for asynchronous kernel
// because OpKernelContext methods like input_type(i) needs the param
// points to valid input type vector. It's not an issue for sync
// kernels because the type vector is kept on the stack.
OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) {
OpKernelContext::Params* ret = new OpKernelContext::Params;
*ret = p;
// Ensure the copy of Params will make a new eigen GPU device if
// necessary.
ret->eigen_gpu_device = nullptr;
ret->inputs = new TensorValueVec(*p.inputs);
ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts);
ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs);
return ret;
}
TensorValueVec saved_inputs;
DeviceContextVec saved_input_device_contexts;
AllocatorAttributeVec saved_input_alloc_attrs;
OpKernelContext::Params params;
TaggedNode tagged_node;
NodeItem item;
Entry* first_input;
OpKernelContext ctx;
NodeExecStats* stats;
// Helpers to delete 'p' and copies made by CopyParams.
void DeleteParams(OpKernelContext::Params* p) {
// No need to delete p->eigen_gpu_device since that is deleted in
// p's destructor
delete p->inputs;
delete p->input_device_contexts;
delete p->input_alloc_attrs;
delete p;
}
} // namespace
private:
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
OpKernelContext::Params* p) {
// Ensure OpKernelContext constructor will make a new eigen GPU device if
// necessary.
p->eigen_gpu_device = nullptr; // Force allocation
return p;
}
};
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
const NodeItem* nodes = impl_->nodes_;
TaggedNodeSeq ready;
std::deque<TaggedNode> inline_ready;
TaggedNodeReadyQueue inline_ready;
// Parameters passed to OpKernel::Compute.
TensorValueVec inputs;
@ -1059,20 +1108,25 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
AsyncOpKernel* async = item.kernel->AsAsync();
DCHECK(async != nullptr);
launched_asynchronously = true;
auto pcopy = CopyParams(params);
auto ctx = new OpKernelContext(pcopy, item.num_outputs);
auto done = [this, tagged_node, item, first_input, ctx, stats, pcopy,
device]() {
AsyncState* state =
new AsyncState(params, tagged_node, item, first_input, stats);
auto done = [this, state]() {
Device* device = impl_->params_.device;
NodeExecStats* stats = state->stats; // Shorthand
Entry* first_input = state->first_input; // Shorthand
if (vlog_) {
VLOG(2) << this << " Async kernel done: "
<< SummarizeNodeDef(item.node->def());
<< SummarizeNodeDef(state->item.node->def());
}
if (stats_collector_) nodestats::SetOpEnd(stats);
EntryVector outputs;
Status s = ProcessOutputs(item, ctx, &outputs, stats);
if (stats_collector_) nodestats::SetMemory(stats, ctx);
Status s = ProcessOutputs(state->item, &state->ctx, &outputs, stats);
if (stats_collector_) nodestats::SetMemory(stats, &state->ctx);
// Clears inputs.
int num_inputs = item.num_inputs;
const int num_inputs = state->item.num_inputs;
for (int i = 0; i < num_inputs; ++i) {
(first_input + i)->val = *kEmptyTensor;
}
@ -1080,31 +1134,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
mutex_lock l(mu_);
tagged_node.input_frame->GetIteration(tagged_node.input_iter)
->mark_completed(tagged_node.node->id());
state->tagged_node.input_frame
->GetIteration(state->tagged_node.input_iter)
->mark_completed(state->tagged_node.node->id());
}
TaggedNodeSeq ready;
if (s.ok()) {
PropagateOutputs(tagged_node, outputs, &ready);
PropagateOutputs(state->tagged_node, outputs, &ready);
}
outputs.clear();
if (s.ok() && pcopy->device->RequiresRecordingAccessedTensors()) {
if (s.ok() &&
state->params.device->RequiresRecordingAccessedTensors()) {
// Get the list of all tensors accessed during the execution
TensorReferenceVector accessed;
ctx->retrieve_accessed_tensors(&accessed);
state->ctx.retrieve_accessed_tensors(&accessed);
if (stats_collector_)
nodestats::SetReferencedTensors(stats, accessed);
// callee takes ownership of the vector
device->ConsumeListOfAccessedTensors(ctx->op_device_context(),
device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
accessed);
}
bool completed = NodeDone(s, item.node, ready, stats, nullptr);
delete ctx;
DeleteParams(pcopy);
bool completed = NodeDone(s, state->item.node, ready, stats, nullptr);
delete state;
if (completed) Finish();
};
if (stats_collector_) nodestats::SetOpStart(stats);
device->ComputeAsync(async, ctx, done);
device->ComputeAsync(async, &state->ctx, done);
} else {
// Synchronous computes.
OpKernelContext ctx(&params, item.num_outputs);
@ -1497,7 +1552,7 @@ void ExecutorState::AddLoopInv(FrameState* frame, const Node* node,
bool ExecutorState::NodeDone(const Status& s, const Node* node,
const TaggedNodeSeq& ready, NodeExecStats* stats,
std::deque<TaggedNode>* inline_ready) {
TaggedNodeReadyQueue* inline_ready) {
if (stats_collector_) {
nodestats::SetAllEnd(stats);
stats_collector_->UpdateCostModelNode(stats, impl_->graph_, node);
@ -1542,7 +1597,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
return completed;
}
void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
void ExecutorState::ProcessInline(const TaggedNodeReadyQueue& inline_ready) {
if (inline_ready.empty()) return;
int64 scheduled_usec = 0;
if (stats_collector_) {
@ -1554,7 +1609,7 @@ void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
}
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
std::deque<TaggedNode>* inline_ready) {
TaggedNodeReadyQueue* inline_ready) {
if (ready.empty()) return;
int64 scheduled_usec = 0;

View File
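TaggedNodeReadyQueue above is the executor-specific instance of a general trick: back a FIFO with a vector, make pop_front just advance an index, and only reclaim storage when the queue drains or the dead prefix passes a threshold. Written generically (a sketch, not the committed class; the 16384 threshold is carried over from above):

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

template <typename T>
class VectorQueue {
 public:
  void push_back(T v) { items_.push_back(std::move(v)); }

  const T& front() const {
    assert(front_ < items_.size());
    return items_[front_];
  }

  void pop_front() {
    assert(front_ < items_.size());
    ++front_;
    if (front_ == items_.size()) {
      items_.clear();  // drained: release everything in one step
      front_ = 0;
    } else if (front_ > kCompactThreshold) {
      // A large dead prefix: shift the live tail down to the start instead
      // of paying deque's per-block allocation on every push/pop.
      items_.erase(items_.begin(), items_.begin() + front_);
      front_ = 0;
    }
  }

  bool empty() const { return front_ == items_.size(); }

 private:
  static constexpr std::size_t kCompactThreshold = 16384;
  std::vector<T> items_;
  std::size_t front_ = 0;
};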

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@ -30,7 +30,7 @@ void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
}
void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
const string& tensor_name,
StringPiece tensor_name,
Device* device, Tensor* cpu_tensor,
StatusCallback done) {
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);

View File

@ -56,9 +56,9 @@ class GPUDeviceContext : public DeviceContext {
Tensor* device_tensor,
StatusCallback done) const override;
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
const string& edge_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
Device* device, Tensor* cpu_tensor,
StatusCallback done) override;
void MaintainLifetimeOnStream(
const Tensor* t, perftools::gputools::Stream* stream) const override {}

View File

@ -39,7 +39,6 @@ namespace test {
Benchmark::Benchmark(const string& device, Graph* g,
const SessionOptions* options, Graph* init) {
SessionOptions default_options;
if (!options) {
options = &default_options;
@ -136,25 +135,35 @@ void Benchmark::RunWithArgs(
args.runner = [this](std::function<void()> closure) {
pool_->Schedule(closure);
};
for (int i = 0; i < 3; ++i) {
static const int kWarmupRuns = 3;
for (int i = 0; i < kWarmupRuns; ++i) {
for (const auto& p : in) {
rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
Rendezvous::ParsedKey parsed;
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
rendez_->Send(parsed, Rendezvous::Args(), p.second, false);
}
TF_CHECK_OK(exec_->Run(args));
for (const string& key : out) {
rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
Rendezvous::ParsedKey parsed;
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead);
}
}
TF_CHECK_OK(device_->Sync());
VLOG(3) << kWarmupRuns << " warmup runs done.";
testing::StartTiming();
while (iters-- > 0) {
for (const auto& p : in) {
rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
Rendezvous::ParsedKey parsed;
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
rendez_->Send(parsed, Rendezvous::Args(), p.second, false);
}
TF_CHECK_OK(exec_->Run(args));
for (const string& key : out) {
rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
Rendezvous::ParsedKey parsed;
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead);
}
}

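The change above just names the warmup count, but the structure it labels — run the workload a few untimed times, sync the device, then start the clock — is the standard benchmarking shape. A generic version outside the TF test harness (a hypothetical helper, wall-clock only):

#include <chrono>
#include <functional>

// Runs 'fn' 'warmup' times untimed (caches, allocators, lazy init), then
// times 'iters' runs; returns mean seconds per iteration, mirroring the
// warmup-then-StartTiming structure of Benchmark::RunWithArgs above.
double TimeWithWarmup(const std::function<void()>& fn, int warmup, int iters) {
  for (int i = 0; i < warmup; ++i) fn();
  const auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) fn();
  const std::chrono::duration<double> elapsed =
      std::chrono::steady_clock::now() - start;
  return elapsed.count() / iters;
}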
View File

@ -36,19 +36,17 @@ IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); }
Status IntraProcessRendezvous::Send(const string& key,
Status IntraProcessRendezvous::Send(const ParsedKey& parsed,
const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) {
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key;
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey();
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
}
Rendezvous::ParsedKey parsed;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
// Buffers "val" and "device_context" in local_.
return local_->Send(key, args, val, is_dead);
return local_->Send(parsed, args, val, is_dead);
}
Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src,
@ -111,24 +109,17 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
done);
}
void IntraProcessRendezvous::RecvAsync(const string& key,
void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
const Rendezvous::Args& recv_args,
DoneCallback done) {
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key;
Rendezvous::ParsedKey parsed;
Status s = ParseKey(key, false /*!is_src*/, &parsed);
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false);
return;
}
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey();
// Recv the tensor from local_.
local_->RecvAsync(key, recv_args, [this, parsed, done](
const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& in, bool is_dead) {
local_->RecvAsync(parsed, recv_args, [this, parsed, done](
const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& in, bool is_dead) {
Status s = status;
Tensor* out = new Tensor;
StatusCallback final_callback = [done, send_args, recv_args, out,

View File

@ -43,14 +43,14 @@ class IntraProcessRendezvous : public Rendezvous {
// Forwards to local_, where the Tensor "val" will be buffered and
// any waiting callback stored.
Status Send(const string& key, const Rendezvous::Args& args,
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) override;
// This method is called only by the RecvOp. It tests to see
// whether the value will be produced by a local or remote device
// and handles accordingly. In the local case it forwards to
// local_, in the remote case it initiates an RPC request.
void RecvAsync(const string& key, const Rendezvous::Args& args,
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
DoneCallback done) override;
void StartAbort(const Status& status) override;

View File

@ -60,23 +60,25 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
return iter->second;
}
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key,
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
const Rendezvous::ParsedKey& parsed,
Rendezvous::DoneCallback done) {
BaseRemoteRendezvous* rendez = FindOrCreate(step_id);
rendez->RecvLocalAsync(
key, [rendez, done](const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& v,
bool dead) {
parsed, [rendez, done](const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& v,
bool dead) {
rendez->Unref();
done(s, send_args, recv_args, v, dead);
});
}
Status BaseRendezvousMgr::RecvLocal(int64 step_id, const string& key,
Status BaseRendezvousMgr::RecvLocal(int64 step_id,
const Rendezvous::ParsedKey& parsed,
Tensor* val, bool* is_dead) {
Status ret;
Notification n;
RecvLocalAsync(step_id, key,
RecvLocalAsync(step_id, parsed,
[val, is_dead, &ret, &n](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
@ -140,38 +142,35 @@ static bool IsLocalDevice(const WorkerEnv& worker,
return device_name.starts_with(worker.worker_name);
}
Status BaseRemoteRendezvous::Send(const string& key,
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) {
VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << key;
VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
}
Rendezvous::ParsedKey parsed;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
if (!IsLocalDevice(*env_, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
env_->worker_name);
return errors::InvalidArgument("Invalid rendezvous key (src): ",
parsed.FullKey(), " @ ", env_->worker_name);
}
// Buffers "val" and "device_context" in local_.
return local_->Send(key, args, val, is_dead);
return local_->Send(parsed, args, val, is_dead);
}
Status BaseRemoteRendezvous::ParseKey(const string& key, bool is_src,
Rendezvous::ParsedKey* parsed) {
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
bool is_src) {
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
}
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
if (is_src && !IsLocalDevice(*env_, parsed->src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
env_->worker_name);
if (is_src && !IsLocalDevice(*env_, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ",
parsed.FullKey(), " @ ", env_->worker_name);
}
if (!is_src && !IsLocalDevice(*env_, parsed->dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ", key, " @ ",
env_->worker_name);
if (!is_src && !IsLocalDevice(*env_, parsed.dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
parsed.FullKey(), " @ ", env_->worker_name);
}
return Status::OK();
}
@ -233,13 +232,11 @@ bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
return DeviceNameUtils::IsSameAddressSpace(src, dst);
}
void BaseRemoteRendezvous::RecvAsync(const string& key,
void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
const Rendezvous::Args& recv_args,
DoneCallback done) {
VLOG(1) << "RemoteRendezvous Recv " << this << " " << key;
Rendezvous::ParsedKey parsed;
Status s = ParseKey(key, false /*!is_src*/, &parsed);
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
Status s = ValidateDevices(parsed, false /*!is_src*/);
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false);
return;
@ -247,12 +244,13 @@ void BaseRemoteRendezvous::RecvAsync(const string& key,
// Are src and dst in the same worker?
if (IsSameWorker(parsed.src, parsed.dst)) {
Rendezvous::ParsedKey parsed_copy = parsed;
// Recv the tensor from local_.
local_->RecvAsync(
key, recv_args, [this, parsed, done](const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& in, bool is_dead) {
parsed_copy, recv_args,
[this, parsed_copy, done](
const Status& status, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
Status s = status;
Tensor* out = new Tensor;
StatusCallback final_callback = [done, send_args, recv_args, out,
@ -262,27 +260,26 @@ void BaseRemoteRendezvous::RecvAsync(const string& key,
};
if (s.ok()) {
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
final_callback);
SameWorkerRecvDone(parsed_copy, send_args, recv_args, in, out,
std::move(final_callback));
} else {
final_callback(s);
}
});
return;
} else {
RecvFromRemoteAsync(key, parsed, recv_args, done);
RecvFromRemoteAsync(parsed, recv_args, std::move(done));
}
}
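The parsed_copy introduced above matters because parsed arrives by const reference and the recv callback may fire long after the caller's stack frame is gone, so the closure must own its own copy. A self-contained illustration of the hazard and the by-value fix (RunLater and the key layout are hypothetical stand-ins):

#include <functional>
#include <iostream>
#include <string>

std::function<void()> g_pending;  // stands in for the rendezvous waiter table

// Hypothetical async API: stashes "cb" to run after the caller has returned.
void RunLater(std::function<void()> cb) { g_pending = std::move(cb); }

void Recv(const std::string& key) {
  // Capturing "key" by reference would dangle once Recv() returns; capture
  // a copy the closure owns instead -- the parsed_copy trick above.
  std::string key_copy = key;
  RunLater([key_copy] { std::cout << "recv done for " << key_copy << "\n"; });
}

int main() {
  Recv("/cpu:0;1;/cpu:1;edge;0:0");
  g_pending();  // the caller's frame is long gone; the copy is still valid
}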
void BaseRemoteRendezvous::RecvLocalAsync(const string& key,
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
DoneCallback done) {
Rendezvous::ParsedKey parsed;
Status s = ParseKey(key, true /* is_src */, &parsed);
Status s = ValidateDevices(parsed, true /* is_src */);
if (!s.ok()) {
done(s, Args(), Args(), Tensor(), false);
return;
}
local_->RecvAsync(key, Args(), done);
local_->RecvAsync(parsed, Args(), std::move(done));
}
void BaseRemoteRendezvous::StartAbort(const Status& s) {

View File

@ -68,12 +68,12 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// "done" when the tensor for "key" is produced or an error occurs.
//
// This method is used by the rpc handler of RecvTensor.
void RecvLocalAsync(int64 step_id, const string& key,
void RecvLocalAsync(int64 step_id, const Rendezvous::ParsedKey& parsed,
Rendezvous::DoneCallback done) override;
// Synchronous wrapper for RecvLocalAsync.
Status RecvLocal(int64 step_id, const string& key, Tensor* val,
bool* is_dead) override;
Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
Tensor* val, bool* is_dead) override;
// Removes rendezvous for "step_id".
//
@ -116,14 +116,14 @@ class BaseRemoteRendezvous : public Rendezvous {
// Forwards to local_, where the Tensor "val" will be buffered and
// any waiting callback stored.
Status Send(const string& key, const Rendezvous::Args& args,
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) override;
// This method is called only by the RecvOp. It tests to see
// whether the value will be produced by a local or remote device
// and handles accordingly. In the local case it forwards to
// local_, in the remote case it initiates an RPC request.
void RecvAsync(const string& key, const Rendezvous::Args& args,
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
DoneCallback done) override;
void StartAbort(const Status& status) override;
@ -134,15 +134,14 @@ class BaseRemoteRendezvous : public Rendezvous {
// network. In either case it needs to retrieve a locally buffered
// value from local_, and give it to its caller.
//
// Runs "done" as soon as the tensor for "key" is available or an error
// Runs "done" as soon as the tensor for "parsed" is available or an error
// is detected.
//
// REQUIRES: "key" is one that will be Saved into the local rendezvous.
void RecvLocalAsync(const string& key, DoneCallback done);
// REQUIRES: "parsed" is one that will be Saved into the local rendezvous.
void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);
protected:
virtual void RecvFromRemoteAsync(const string& key,
const Rendezvous::ParsedKey& parsed,
virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) = 0;
@ -174,11 +173,10 @@ class BaseRemoteRendezvous : public Rendezvous {
// Active outstanding RecvTensor calls.
std::unordered_set<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
// Parses "key" into "parsed". If "is_src" is true, checks that the
// rendezvous key's source is in this process. If "is_src" is false,
// checks that the rendezvous key's destination is in this process.
Status ParseKey(const string& key, bool is_src,
Rendezvous::ParsedKey* parsed);
// If "is_src" is true, checks that the rendezvous key "parsed"'s
// source is in this process. If "is_src" is false, checks that the
// rendezvous key "parsed"'s destination is in this process.
Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);
// Callback handling the case when a rendezvous has been
// accomplished in local_ and the consumer is local to this process.

View File

@ -33,7 +33,7 @@ void CallOptions::StartCancel() {
void CallOptions::SetCancelCallback(CancelFunction cancel_func) {
mutex_lock l(mu_);
cancel_func_ = cancel_func;
cancel_func_ = std::move(cancel_func);
}
void CallOptions::ClearCancelCallback() {

View File

@ -127,10 +127,15 @@ float V(const Tensor& tensor) {
static uint64 kIncarnation = 1; // Used in the following tests.
string Key(const string& sender, const uint64 incarnation,
const string& receiver, const string& name) {
return Rendezvous::CreateKey(sender, incarnation, receiver, name,
FrameAndIter(0, 0));
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
const string& receiver, const string& name) {
Rendezvous::ParsedKey result;
CHECK(
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
name, FrameAndIter(0, 0)),
&result)
.ok());
return result;
}
#define ALICE "/job:j/replica:0/task:0/cpu:0"

View File

@ -306,10 +306,15 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
// Sends values specified by the caller.
Rendezvous::ParsedKey parsed;
for (const auto& p : in) {
const string& key = p.first;
const Tensor& val = p.second;
const Status s = rendezvous->Send(key, Rendezvous::Args(), val, false);
Status s = Rendezvous::ParseKey(key, &parsed);
if (s.ok()) {
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
}
if (!s.ok()) {
done(s);
item->Unref();
@ -337,7 +342,10 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
LogMemory::RecordStep(args.step_id, handle);
}
thread::ThreadPool* pool = worker_env_->compute_pool;
args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
using namespace std::placeholders;
// The line below is equivalent to this code, but performs one fewer indirect call:
// args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
args.runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
for (const auto& unit : item->units) {
unit.root->RunAsync(args, barrier->Get());
}
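As the new comment notes, the std::bind form saves one indirection: the lambda version makes std::function dispatch into a closure that then calls pool->Schedule, while bind lets the wrapper invoke the member function directly. A compilable sketch of the two spellings against a toy pool (the ThreadPool here is a stand-in, not the TensorFlow class):

#include <functional>
#include <iostream>

// Toy stand-in for tensorflow::thread::ThreadPool.
struct ThreadPool {
  void Schedule(std::function<void()> fn) { fn(); }  // runs inline for the demo
};

int main() {
  using namespace std::placeholders;
  ThreadPool pool;
  std::function<void(std::function<void()>)> runner;

  // Lambda form: one extra hop through the enclosing closure.
  runner = [&pool](std::function<void()> fn) { pool.Schedule(std::move(fn)); };
  runner([] { std::cout << "via lambda\n"; });

  // bind form: the std::function wrapper invokes the member call directly.
  runner = std::bind(&ThreadPool::Schedule, &pool, _1);
  runner([] { std::cout << "via bind\n"; });
}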
@ -347,11 +355,15 @@ void GraphMgr::RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
StatusCallback done, Status s) {
if (s.ok()) {
// Receives values requested by the caller.
Rendezvous::ParsedKey parsed;
for (auto& p : *out) {
const string& key = p.first;
Tensor* val = &p.second;
bool is_dead = false;
s = rendezvous->Recv(key, Rendezvous::Args(), val, &is_dead);
s = Rendezvous::ParseKey(key, &parsed);
if (s.ok()) {
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
}
if (is_dead) {
s = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");

View File

@ -57,12 +57,13 @@ class RendezvousMgrInterface {
// "done" when the tensor for "key" is produced or an error occurs.
//
// This method is used by the rpc handler of RecvTensor.
virtual void RecvLocalAsync(int64 step_id, const string& key,
virtual void RecvLocalAsync(int64 step_id,
const Rendezvous::ParsedKey& parsed,
Rendezvous::DoneCallback done) = 0;
// Synchronous wrapper for RecvLocalAsync.
virtual Status RecvLocal(int64 step_id, const string& key, Tensor* val,
bool* is_dead) = 0;
virtual Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
Tensor* val, bool* is_dead) = 0;
// Removes rendezvous for "step_id".
//

View File

@ -97,49 +97,66 @@ class GrpcRemoteWorker : public WorkerInterface {
req_copy->set_dma_ok(false);
}
// Type-specialized logging for this method.
StatusCallback logging_callback = [this, request, req_copy, response, done,
start_usec](Status s) {
if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id();
int64 bytes = response->tensor().ByteSize();
int64 send_start_usec = start_usec;
// If a send start time was reported by the other side, use
// that instead. Maybe we should mark the display if we're using
// our local time instead of the remote start time?
if (response->send_start_micros()) {
// send_start_micros is the timestamp taken when the remote
// machine began to send the RecvTensor response.
// Due to clock skew between source and dest machines, it is
// possible that send_start_micros can be larger than end_usec or
// less than start_usec.
// To respect causality, we enforce the invariants that the RecvTensor
// response can not have been sent before the RecvTensor request, and
// must have been sent before it was received.
send_start_usec = std::max(start_usec, response->send_start_micros());
send_start_usec = std::min(send_start_usec, end_usec - 1);
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
StatusCallback wrapper_done;
const StatusCallback* cb_to_use;
if (!logging_active && req_copy == nullptr) {
cb_to_use = &done; // No additional work to do, so just use done directly
} else if (!logging_active) {
wrapper_done = [req_copy, done](Status s) {
delete req_copy;
done(s);
};
cb_to_use = &wrapper_done;
} else {
wrapper_done = [this, request, req_copy, response, done,
start_usec](Status s) {
if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id();
int64 bytes = response->tensor().ByteSize();
int64 send_start_usec = start_usec;
// If a send start time was reported by the other side, use
// that instead. Maybe we should mark the display if we're using
// our local time instead of the remote start time?
if (response->send_start_micros()) {
// send_start_micros is the timestamp taken when the
// remote machine began to send the RecvTensor response.
// Due to clock skew between source and dest machines, it
// is possible that send_start_micros can be larger than
// end_usec or less than start_usec.
//
// To respect causality, we enforce the invariants that
// the RecvTensor response can not have been sent before
// the RecvTensor request, and must have been sent before
// it was received.
send_start_usec =
std::max(start_usec, response->send_start_micros());
send_start_usec = std::min(send_start_usec, end_usec - 1);
}
const string& key = request->rendezvous_key();
std::vector<string> key_parts = str_util::Split(key, ';');
if (key_parts.size() != 5) {
LOG(WARNING) << "Bad key: " << key;
} else {
logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
key_parts[3], // tensor name
key_parts[0], // src_device
key_parts[2], // dst_device
bytes);
}
}
const string& key = request->rendezvous_key();
std::vector<string> key_parts = str_util::Split(key, ';');
if (key_parts.size() != 5) {
LOG(WARNING) << "Bad key: " << key;
} else {
logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
key_parts[3], // tensor name
key_parts[0], // src_device
key_parts[2], // dst_device
bytes);
}
}
VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->DebugString();
delete req_copy;
done(s);
};
VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->DebugString();
delete req_copy;
done(s);
};
cb_to_use = &wrapper_done;
}
IssueRequest(req_copy ? req_copy : request, response,
&grpc::WorkerService::Stub::AsyncRecvTensor, logging_callback,
call_opts);
&grpc::WorkerService::Stub::AsyncRecvTensor,
std::move(*cb_to_use), call_opts);
}
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
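The restructuring above exists to keep the fast path cheap: when logging is off and no request copy was made, the original done callback is used as-is, and the expensive wrapper closure with its timestamp bookkeeping is only built when logging is active. Note also the causality clamp, which forces send_start_usec into [start_usec, end_usec - 1] so clock skew can never make a response appear to precede its request. A condensed sketch of the callback-selection pattern (all names are illustrative):

#include <functional>
#include <iostream>

using StatusCallback = std::function<void(int /*status*/)>;

void Issue(const StatusCallback& cb) { cb(0); }

void RecvTensor(bool logging_active, int* req_copy, StatusCallback done) {
  StatusCallback wrapper;            // only populated when extra work is needed
  const StatusCallback* cb_to_use;   // points at whichever callback will run
  if (!logging_active && req_copy == nullptr) {
    cb_to_use = &done;               // fast path: no wrapper allocation at all
  } else {
    wrapper = [req_copy, done](int s) {
      delete req_copy;               // cleanup the wrapper must own
      done(s);
    };
    cb_to_use = &wrapper;
  }
  Issue(*cb_to_use);
}

int main() {
  RecvTensor(false, nullptr, [](int s) { std::cout << "done " << s << "\n"; });
  RecvTensor(true, new int(7), [](int s) { std::cout << "logged " << s << "\n"; });
}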

View File

@ -349,11 +349,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status PrepareRecvTensor(const string& key, Device** src_dev) {
// Validate the key.
Rendezvous::ParsedKey parsed;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
Device** src_dev) {
// Figures out which device the tensor is hosted on.
TF_RETURN_IF_ERROR(
env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
@ -375,8 +372,12 @@ class GrpcWorkerService : public AsyncServiceInterface {
const int64 step_id = call->request.step_id();
const string& key = call->request.rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(key, &parsed);
Device* src_dev = nullptr;
Status s = PrepareRecvTensor(key, &src_dev);
if (s.ok()) {
s = PrepareRecvTensor(parsed, &src_dev);
}
if (!s.ok()) {
call->SendResponse(ToGrpcStatus(s));
return;
@ -388,7 +389,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
// cancellation should abort the rendezvous.
call->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
env_->rendezvous_mgr->RecvLocalAsync(
step_id, key,
step_id, parsed,
[this, call, src_dev](const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,

View File

@ -41,8 +41,7 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
: BaseRemoteRendezvous(env, step_id, false) {}
protected:
void RecvFromRemoteAsync(const string& key,
const Rendezvous::ParsedKey& parsed,
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) override;
@ -55,23 +54,49 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
// Used only to retrieve tensors from remote processes.
class RpcRecvTensorCall : public BaseRecvTensorCall {
public:
RpcRecvTensorCall(WorkerCacheInterface* wc, WorkerInterface* wi,
int64 step_id, const string& key,
const string& remote_dev, Allocator* allocator,
Device* dst_device)
: wi_(wi),
wc_(wc),
remote_dev_(remote_dev),
allocator_(allocator),
dst_(dst_device) {
RpcRecvTensorCall()
: wi_(nullptr), wc_(nullptr), allocator_(nullptr), dst_device_(nullptr) {}
void Init(WorkerCacheInterface* wc, WorkerInterface* wi, int64 step_id,
StringPiece key, Allocator* allocator, Device* dst_device,
const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
wi_ = wi;
wc_ = wc;
allocator_ = allocator;
dst_device_ = dst_device;
recv_args_ = recv_args;
done_ = std::move(done);
req_.set_step_id(step_id);
req_.set_rendezvous_key(key);
req_.set_rendezvous_key(key.data(), key.size());
}
void Reset() {
delete wi_;
wi_ = nullptr;
wc_ = nullptr;
allocator_ = nullptr;
dst_device_ = nullptr;
// We don't clear opts_ and assume that Init will set up the state for
// opts_ appropriately.
req_.Clear();
if (resp_.ByteSize() > 128) {
// Clear memory from resp_ if it is too large
RecvTensorResponse empty;
resp_.Swap(&empty);
} else {
resp_.Clear();
}
{
mutex_lock l(mu_);
status_ = Status::OK();
}
done_ = nullptr;
}
~RpcRecvTensorCall() override { delete wi_; }
void Start(std::function<void()> recv_done) override {
StartRTCall(recv_done);
StartRTCall(std::move(recv_done));
}
void StartAbort(const Status& s) override {
@ -93,6 +118,10 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
bool is_dead() const { return resp_.is_dead(); }
Device* dst_device() const { return dst_device_; }
const Rendezvous::Args& recv_args() const { return recv_args_; }
const Rendezvous::DoneCallback& done() const { return done_; }
private:
// Start the main RecvTensor call, checking for an async abort.
void StartRTCall(std::function<void()> recv_done) {
@ -100,7 +129,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
nullptr /* TensorBufAllocator */,
// done callback
[this, recv_done](const Status& s) {
{
if (!s.ok()) {
mutex_lock l(mu_);
status_.Update(s);
}
@ -110,12 +139,13 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
WorkerInterface* wi_; // Owned.
WorkerCacheInterface* wc_; // Not owned.
string remote_dev_;
Allocator* allocator_;
Device* dst_;
Device* dst_device_;
CallOptions opts_;
RecvTensorRequest req_;
RecvTensorResponse resp_;
Rendezvous::Args recv_args_;
Rendezvous::DoneCallback done_;
mutable mutex mu_;
Status status_ GUARDED_BY(mu_);
@ -123,10 +153,53 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
};
namespace {
class RpcRecvTensorFreeList {
public:
RpcRecvTensorFreeList() {}
~RpcRecvTensorFreeList() {
for (int i = 0; i < objects_.size(); i++) {
delete objects_[i];
}
}
RpcRecvTensorCall* New() {
{
mutex_lock l(mu_);
if (!objects_.empty()) {
RpcRecvTensorCall* result = objects_.back();
objects_.pop_back();
return result;
}
}
return new RpcRecvTensorCall;
}
void Release(RpcRecvTensorCall* obj) {
obj->Reset();
{
mutex_lock l(mu_);
if (objects_.size() < kMaxObjects) {
objects_.push_back(obj);
return;
}
}
delete obj;
}
private:
static const int kMaxObjects = 1000;
mutex mu_;
std::vector<RpcRecvTensorCall*> objects_ GUARDED_BY(mu_);
};
static RpcRecvTensorFreeList call_freelist_;
}
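RpcRecvTensorFreeList is a bounded object pool: New() pops a recycled call object when one is parked, Release() resets the object and parks it unless the pool is already at kMaxObjects, and anything beyond the bound is simply deleted so a traffic burst cannot pin memory forever. A generic sketch of the same shape, assuming only that the pooled type knows how to Reset() itself:

#include <cstddef>
#include <mutex>
#include <vector>

template <typename T>
class FreeList {
 public:
  ~FreeList() {
    for (T* obj : objects_) delete obj;
  }
  T* New() {
    {
      std::lock_guard<std::mutex> l(mu_);
      if (!objects_.empty()) {
        T* result = objects_.back();
        objects_.pop_back();
        return result;   // reuse: skips a heap allocation
      }
    }
    return new T;        // pool empty: fall back to the heap
  }
  void Release(T* obj) {
    obj->Reset();        // T must return itself to a pristine state
    {
      std::lock_guard<std::mutex> l(mu_);
      if (objects_.size() < kMaxObjects) {
        objects_.push_back(obj);
        return;
      }
    }
    delete obj;          // pool full: don't hoard memory
  }

 private:
  static constexpr std::size_t kMaxObjects = 1000;
  std::mutex mu_;
  std::vector<T*> objects_;  // guarded by mu_
};

struct Call { int payload = 0; void Reset() { payload = 0; } };

int main() {
  FreeList<Call> fl;
  Call* c = fl.New();
  c->payload = 42;
  fl.Release(c);        // c is reset and parked for reuse
  Call* c2 = fl.New();  // the same object comes back, payload == 0
  fl.Release(c2);
}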
void RpcRemoteRendezvous::RecvFromRemoteAsync(
const string& key, const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& recv_args, DoneCallback done) {
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) {
Status s;
// key.src_device identifies a remote device.
@ -137,11 +210,15 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
s = errors::Internal(parsed.src_device,
" is invalid remote source device.");
}
// TODO(jeff): Consider checking for a valid worker_cache during the
// constructor of RpcRemoteRendezvous, rather than here, to simplify
// the twisty logic below.
WorkerCacheInterface* worker_cache = env_->worker_cache;
if (s.ok() && worker_cache == nullptr) {
s = errors::Internal("No remote worker cache available.");
}
WorkerInterface* rwi = env_->worker_cache->CreateWorker(src_worker);
WorkerInterface* rwi =
(worker_cache ? worker_cache->CreateWorker(src_worker) : nullptr);
if (s.ok() && rwi == nullptr) {
s = errors::Internal("No worker known as ", src_worker);
}
@ -157,15 +234,16 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs);
// Prepare a RecvTensor call that can handle being aborted.
RpcRecvTensorCall* call =
new RpcRecvTensorCall(worker_cache, rwi, step_id_, key,
parsed.src_device, allocator, dst_device);
RpcRecvTensorCall* call = call_freelist_.New();
call->Init(worker_cache, rwi, step_id_, parsed.FullKey(), allocator,
dst_device, recv_args, std::move(done));
// Record "call" in active_ so that it can be aborted cleanly.
RegisterCall(call);
// Start "call".
call->Start([this, call, parsed, recv_args, done]() {
call->Start([this, call]() {
// Removes "call" from active_. Prevent StartAbort().
DeregisterCall(call);
// If StartAbort was called prior to DeregisterCall, then the
@ -173,24 +251,19 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
Status s = call->status();
Tensor val;
if (s.ok()) {
Device* dst_device;
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
if (s.ok()) {
s = dst_device->MakeTensorFromProto(call->tensor_proto(),
recv_args.alloc_attrs, &val);
}
s = call->dst_device()->MakeTensorFromProto(
call->tensor_proto(), call->recv_args().alloc_attrs, &val);
}
done(s, Args(), recv_args, val, call->is_dead());
delete call;
call->done()(s, Args(), call->recv_args(), val, call->is_dead());
call_freelist_.Release(call);
});
}
} // namespace
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
const WorkerEnv* worker_env) {
const WorkerEnv* worker_env) {
return new RpcRemoteRendezvous(worker_env, step_id);
}
} // end namespace tensorflow

View File

@ -40,15 +40,21 @@ string V(const Tensor& tensor) {
return tensor.scalar<string>()();
}
Rendezvous::ParsedKey MakeKey(const string& s) {
Rendezvous::ParsedKey key;
CHECK(Rendezvous::ParseKey(s, &key).ok());
return key;
}
TEST(RpcRendezvousMgrTest, LocalSendRecv) {
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
RpcRendezvousMgr rmgr(&env);
const int64 step_id = 123;
const string key = Rendezvous::CreateKey(
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
Rendezvous* rendez = rmgr.Find(step_id);
core::ScopedUnref unref(rendez);
@ -69,9 +75,9 @@ TEST(RpcRendezvousMgrTest, LocalAbort) {
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
RpcRendezvousMgr rmgr(&env);
const string key = Rendezvous::CreateKey(
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ // Explicit Abort().
const int64 step_id = 123;
Rendezvous* rendez = rmgr.Find(step_id);
@ -105,9 +111,9 @@ TEST(RpcRendezvousMgrTest, CleanupAll) {
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
RpcRendezvousMgr rmgr(&env);
const string key = Rendezvous::CreateKey(
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
const int64 step_id = 123;
Rendezvous* rendez = rmgr.Find(step_id);
@ -139,9 +145,9 @@ TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
env.worker_name = "/job:mnist/replica:1/task:2";
RpcRendezvousMgr rmgr(&env);
const int64 step_id = 123;
const string key = Rendezvous::CreateKey(
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
Rendezvous* rendez = rmgr.Find(step_id);
core::ScopedUnref unref(rendez);

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
namespace Eigen {
@ -80,7 +81,7 @@ class DeviceContext : public core::RefCounted {
// device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
// to be of the same size as "device_tensor".
virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor,
const string& tensor_name, Device* device,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) {
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -30,6 +31,21 @@ limitations under the License.
namespace tensorflow {
Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) {
const char* b_base = b.buf_.data();
buf_ = b.buf_;
src_device.set(buf_.data() + (b.src_device.data() - b_base),
b.src_device.size());
src = b.src;
src_incarnation = b.src_incarnation;
dst_device.set(buf_.data() + (b.dst_device.data() - b_base),
b.dst_device.size());
dst = b.dst;
edge_name.set(buf_.data() + (b.edge_name.data() - b_base),
b.edge_name.size());
return *this;
}
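This hand-written operator= is needed because ParsedKey keeps the full key string in buf_ while src_device, dst_device, and edge_name are StringPieces pointing into that buffer; a memberwise copy would leave the new object's pieces aliasing the old object's storage. The assignment therefore copies the buffer and re-bases each piece at its original offset. A self-contained illustration of the same rebasing trick using std::string_view (C++17):

#include <cassert>
#include <string>
#include <string_view>

struct Key {
  std::string buf;             // owns the bytes
  std::string_view edge_name;  // points into buf

  Key(std::string s, size_t pos, size_t len) : buf(std::move(s)) {
    edge_name = std::string_view(buf).substr(pos, len);
  }
  Key(const Key& b) { *this = b; }
  Key& operator=(const Key& b) {
    const char* b_base = b.buf.data();
    buf = b.buf;  // fresh storage; b's views must not be copied verbatim
    edge_name = std::string_view(buf.data() + (b.edge_name.data() - b_base),
                                 b.edge_name.size());
    return *this;
  }
};

int main() {
  Key a("/cpu:0;1;/cpu:1;foo;0:0", 16, 3);  // edge_name == "foo"
  Key b = a;
  assert(b.edge_name == "foo");
  // Crucially, b.edge_name points into b.buf, not into a.buf:
  assert(b.edge_name.data() >= b.buf.data() &&
         b.edge_name.data() < b.buf.data() + b.buf.size());
}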
/* static */
string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
const string& dst_device, const string& name,
@ -66,7 +82,8 @@ static StringPiece ConsumeNextPart(StringPiece* s, char delim) {
/* static */
Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
StringPiece s(key);
out->buf_ = key; // Make a copy that our StringPieces can point at
StringPiece s(out->buf_);
StringPiece parts[5];
for (int i = 0; i < 5; i++) {
parts[i] = ConsumeNextPart(&s, ';');
@ -77,9 +94,9 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
strings::HexStringToUint64(parts[1], &out->src_incarnation) &&
DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
!parts[3].empty()) {
out->src_device.assign(parts[0].data(), parts[0].size());
out->dst_device.assign(parts[2].data(), parts[2].size());
out->edge_name.assign(parts[3].data(), parts[3].size());
out->src_device.set(parts[0].data(), parts[0].size());
out->dst_device.set(parts[2].data(), parts[2].size());
out->edge_name.set(parts[3].data(), parts[3].size());
return Status::OK();
}
return errors::InvalidArgument("Invalid rendezvous key: ", key);
@ -87,8 +104,8 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
Rendezvous::~Rendezvous() {}
Status Rendezvous::Recv(const string& key, const Args& recv_args, Tensor* val,
bool* is_dead) {
Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
Tensor* val, bool* is_dead) {
Status ret;
Notification n;
RecvAsync(key, recv_args,
@ -109,18 +126,19 @@ class LocalRendezvousImpl : public Rendezvous {
explicit LocalRendezvousImpl(bool tolerate_dup_recv)
: tolerate_dup_recv_(tolerate_dup_recv) {}
Status Send(const string& key, const Args& send_args, const Tensor& val,
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
const bool is_dead) override {
VLOG(2) << "Send " << this << " " << key;
DoneCallback waiter = nullptr;
Args recv_args;
uint64 key_hash = KeyHash(key.FullKey());
VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
{
mutex_lock l(mu_);
if (!status_.ok()) {
return status_;
}
Item* item = nullptr;
Table::iterator iter = table_.find(key);
Table::iterator iter = table_.find(key_hash);
if (iter == table_.end()) {
// There is no waiter for this message. Insert the message
// into the waiters table. The waiter will pick it up when
@ -138,19 +156,24 @@ class LocalRendezvousImpl : public Rendezvous {
// The allocator attributes of item->value.
item->send_alloc_attrs = send_args.alloc_attrs;
CHECK(table_.insert({key, item}).second);
CHECK(table_.insert({key_hash, item}).second);
return Status::OK();
} else {
item = iter->second;
if (item->waiter == nullptr) {
// There is already a message in the table under the key.
// Should not happen unless it has a waiter.
return errors::Aborted("Duplicated send: ", key);
return errors::Aborted("Duplicated send: ", key.FullKey());
}
// Mark item as complete.
item->has_been_recvd = true;
waiter = item->waiter;
item->waiter = nullptr;
// Get item->waiter function into waiter and set item->waiter to null
std::swap(item->waiter, waiter);
DCHECK(item->waiter == nullptr);
DCHECK(waiter != nullptr);
// The ref on recv_dev_context transfers below.
recv_args.device_context = item->recv_dev_context;
recv_args.alloc_attrs = item->recv_alloc_attrs;
@ -173,9 +196,10 @@ class LocalRendezvousImpl : public Rendezvous {
return Status::OK();
}
void RecvAsync(const string& key, const Args& recv_args,
void RecvAsync(const ParsedKey& key, const Args& recv_args,
DoneCallback done) override {
VLOG(2) << "Recv " << this << " " << key;
uint64 key_hash = KeyHash(key.FullKey());
VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
mu_.lock();
if (!status_.ok()) {
// Rendezvous has been aborted.
@ -184,13 +208,13 @@ class LocalRendezvousImpl : public Rendezvous {
done(s, Args(), recv_args, Tensor(), false);
return;
}
Table::iterator iter = table_.find(key);
Table::iterator iter = table_.find(key_hash);
if (iter != table_.end()) {
Item* item = iter->second;
if (item->has_been_recvd && !tolerate_dup_recv_) {
mu_.unlock();
done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
Tensor(), false);
done(errors::Aborted("Duplicated recv: ", key.FullKey()), Args(),
recv_args, Tensor(), false);
} else if (item->waiter == nullptr || tolerate_dup_recv_) {
// A message has already arrived and is stored in the table
// under this key. Consumes the message and invokes the done
@ -218,8 +242,8 @@ class LocalRendezvousImpl : public Rendezvous {
// Already have a waiter in the waiters table under this key,
// which should not happen.
mu_.unlock();
done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
Tensor(), false);
done(errors::Aborted("Duplicated recv: ", key.FullKey()), Args(),
recv_args, Tensor(), false);
}
return;
}
@ -227,13 +251,13 @@ class LocalRendezvousImpl : public Rendezvous {
// waiting table. The done closure will be invoked when the
// message arrives.
Item* item = new Item;
item->waiter = done;
item->waiter = std::move(done);
item->recv_alloc_attrs = recv_args.alloc_attrs;
if (recv_args.device_context) {
item->recv_dev_context = recv_args.device_context;
item->recv_dev_context->Ref();
}
CHECK(table_.insert({key, item}).second);
CHECK(table_.insert({key_hash, item}).second);
mu_.unlock();
return;
}
@ -280,7 +304,12 @@ class LocalRendezvousImpl : public Rendezvous {
}
}
};
typedef std::unordered_map<string, Item*> Table;
// We key the hash table by KeyHash of the Rendezvous::CreateKey string
static uint64 KeyHash(const StringPiece& k) {
return Hash64(k.data(), k.size());
}
typedef std::unordered_map<uint64, Item*> Table;
// TODO(zhifengc): shard table_.
mutex mu_;
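Keying the table by Hash64 of the full key means each lookup hashes and compares a single uint64 instead of storing long device/edge strings as map keys; the implicit trade-off is that a 64-bit hash collision would silently alias two rendezvous keys, which the implementation accepts as vanishingly unlikely. A sketch of the idea, with std::hash standing in for Hash64:

#include <cstdint>
#include <iostream>
#include <string>
#include <string_view>
#include <unordered_map>

struct Item { std::string value; };

// Stand-in for tensorflow's Hash64(data, size).
static uint64_t KeyHash(std::string_view k) {
  return std::hash<std::string_view>{}(k);
}

int main() {
  // Keyed by the hash, not by the (long) key string itself.
  std::unordered_map<uint64_t, Item*> table;
  std::string key = "/cpu:0;0000000000000001;/cpu:1;edge_1;0:0";
  table[KeyHash(key)] = new Item{"hello"};

  auto it = table.find(KeyHash(key));
  if (it != table.end()) std::cout << it->second->value << "\n";
  delete table[KeyHash(key)];
}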

View File

@ -54,12 +54,22 @@ class Rendezvous : public core::RefCounted {
// Parses the key constructed by CreateKey and parse src/dst device
// names into structures respectively.
struct ParsedKey {
string src_device;
StringPiece src_device;
DeviceNameUtils::ParsedName src;
uint64 src_incarnation = 0;
string dst_device;
StringPiece dst_device;
DeviceNameUtils::ParsedName dst;
string edge_name;
StringPiece edge_name;
ParsedKey() {}
ParsedKey(const ParsedKey& b) { *this = b; }
ParsedKey& operator=(const ParsedKey& b);
StringPiece FullKey() const { return buf_; }
private:
friend class Rendezvous;
string buf_;
};
static Status ParseKey(const string& key, ParsedKey* out);
@ -74,7 +84,7 @@ class Rendezvous : public core::RefCounted {
// Send/Recv on the same worker.
//
// Send() never blocks.
virtual Status Send(const string& key, const Args& args, const Tensor& val,
virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
const bool is_dead) = 0;
// Callback provided by a tensor consumer waiting on the rendezvous.
@ -84,13 +94,15 @@ class Rendezvous : public core::RefCounted {
// receiver, which may be needed when a non-CPU device is in use
// by either side.
typedef std::function<void(const Status&, const Args&, const Args&,
const Tensor&, const bool)> DoneCallback;
const Tensor&, const bool)>
DoneCallback;
virtual void RecvAsync(const string& key, const Args& args,
virtual void RecvAsync(const ParsedKey& key, const Args& args,
DoneCallback done) = 0;
// Synchronous wrapper for RecvAsync.
Status Recv(const string& key, const Args& args, Tensor* val, bool* is_dead);
Status Recv(const ParsedKey& key, const Args& args, Tensor* val,
bool* is_dead);
// Aborts all pending and future Send/Recv with the given "status".
//

View File

@ -96,13 +96,29 @@ string V(const Tensor& tensor) {
return tensor.scalar<string>()();
}
const char* kFoo = "/cpu:0;1;/cpu:1;foo;1;2";
const char* kBar = "/gpu:0;2;/gpu:1;bar;1;2";
Rendezvous::ParsedKey MakeKey(const string& name) {
string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
"/job:mnist/replica:1/task:2/GPU:0", name,
FrameAndIter(0, 0));
Rendezvous::ParsedKey k;
TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
return k;
}
Rendezvous::ParsedKey KeyFoo() { return MakeKey("foo"); }
Rendezvous::ParsedKey KeyBar() { return MakeKey("bar"); }
TEST_F(LocalRendezvousTest, SendRecv) {
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, V("hello"), false)));
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
EXPECT_TRUE(
errors::IsAborted(rendez_->Send(KeyFoo(), args, V("hello"), false)));
Tensor val(DT_STRING);
bool is_dead = false;
TF_ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
EXPECT_EQ("hello", V(val));
}
@ -110,12 +126,12 @@ TEST_F(LocalRendezvousTest, RecvSend) {
SchedClosure([this]() {
Env::Default()->SleepForMicroseconds(10000);
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
});
Tensor val(DT_STRING);
bool is_dead = false;
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
EXPECT_EQ("hello", V(val));
}
@ -124,16 +140,17 @@ TEST_F(LocalRendezvousTest, DuplicateWaiterRecv) {
Tensor t(DT_STRING);
bool is_dead = false;
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
TF_ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
});
Env::Default()->SleepForMicroseconds(1000000);
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
TF_ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
TF_ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
EXPECT_TRUE(
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
EXPECT_EQ("secret msg", V(val));
}
@ -142,17 +159,18 @@ TEST_F(LocalRendezvousTest, DuplicateSerialRecv) {
Tensor t(DT_STRING);
bool is_dead = false;
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
TF_ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
});
Env::Default()->SleepForMicroseconds(1000000);
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
TF_ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
EXPECT_EQ("secret msg", V(val));
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
EXPECT_TRUE(
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
}
// A simple structure that behaves a bit like a blocking counter. The
@ -174,7 +192,7 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
random::SimplePhilox rnd(&philox);
Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000));
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Send(strings::StrCat(i), args,
TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args,
V(strings::StrCat(i)), false));
});
SchedClosure([this, &state, i]() {
@ -184,7 +202,8 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Recv(strings::StrCat(i), args, &val, &val_dead));
TF_ASSERT_OK(
rendez_->Recv(MakeKey(strings::StrCat(i)), args, &val, &val_dead));
EXPECT_EQ(strings::StrCat(i), V(val));
bool done = false;
{
@ -212,7 +231,7 @@ TEST_F(LocalRendezvousTest, RecvAbort) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
Status status = rendez_->Recv("foo", args, &val, &val_dead);
Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
EXPECT_TRUE(errors::IsAborted(status));
}
@ -228,7 +247,7 @@ TEST_F(LocalRendezvousTest, RecvSleepAbort) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
Status status = rendez_->Recv("foo", args, &val, &val_dead);
Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
EXPECT_TRUE(errors::IsAborted(status));
}
@ -237,8 +256,9 @@ TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, val, val_dead)));
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
EXPECT_TRUE(errors::IsAborted(rendez_->Send(KeyFoo(), args, val, val_dead)));
EXPECT_TRUE(
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
}
class DummyDeviceContext : public DeviceContext {
@ -255,15 +275,15 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
Rendezvous::Args args;
args.device_context = new DummyDeviceContext(123);
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
Notification n;
Rendezvous::Args args1;
args1.device_context = new DummyDeviceContext(1);
rendez_->RecvAsync("foo", args1, [&n](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& val, bool is_dead) {
rendez_->RecvAsync(KeyFoo(), args1, [&n](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& val, bool is_dead) {
CHECK_EQ(123,
dynamic_cast<const DummyDeviceContext*>(send_args.device_context)
->stream_id());
@ -284,8 +304,8 @@ static void BM_SendRecv(int iters) {
Status s;
if (iters > 0) {
while (iters--) {
s = rendez->Send("foo", args, orig, is_dead);
s = rendez->Recv("foo", args, &val, &is_dead);
s = rendez->Send(KeyFoo(), args, orig, is_dead);
s = rendez->Recv(KeyFoo(), args, &val, &is_dead);
}
CHECK_EQ(V(val), V(orig));
}
@ -307,8 +327,8 @@ static void BM_RecvSend(int iters) {
Rendezvous::Args args;
Status s;
for (int i = 0; i < iters / 2; ++i) {
s = rendez->Recv("foo", args, &foo, &is_dead);
s = rendez->Send("bar", args, bar, is_dead);
s = rendez->Recv(KeyFoo(), args, &foo, &is_dead);
s = rendez->Send(KeyBar(), args, bar, is_dead);
}
CHECK_EQ("foo", V(foo));
});
@ -318,8 +338,8 @@ static void BM_RecvSend(int iters) {
Rendezvous::Args args;
Status s;
for (int i = 0; i < iters / 2; ++i) {
s = rendez->Send("foo", args, foo, is_dead);
s = rendez->Recv("bar", args, &bar, &is_dead);
s = rendez->Send(KeyFoo(), args, foo, is_dead);
s = rendez->Recv(KeyBar(), args, &bar, &is_dead);
}
CHECK_EQ("bar", V(bar));
delete pool;

View File

@ -118,6 +118,7 @@ cc_library(
cc_library(
name = "fill_functor",
srcs = ["fill_functor.cc"],
hdrs = ["fill_functor.h"],
deps = [
"//tensorflow/core:framework",
@ -1032,7 +1033,6 @@ tf_kernel_libraries(
],
deps = [
":bounds_check",
":constant_op",
":fill_functor",
":transpose_functor",
"//tensorflow/core:core_cpu",
@ -1541,7 +1541,6 @@ tf_kernel_libraries(
],
deps = [
":bounds_check",
":constant_op",
":cwise_op",
":fill_functor",
":scatter_op",
@ -1723,6 +1722,7 @@ filegroup(
"dense_update_ops.cc",
"dense_update_ops.h",
"example_parsing_ops.cc",
"fill_functor.cc",
"fill_functor.h",
"gather_op.cc",
"gather_op.h",

View File

@ -114,37 +114,6 @@ struct FillFunctor<CPUDevice, T> {
}
};
// Partial specialization of SetZeroFunctor<Device=CPUDevice, T>.
template <typename T>
struct SetZeroFunctor<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat out) {
out.device(d) = out.constant(T(0));
}
};
// Specialization of SetZeroFunctor<Device=CPUDevice, T=string>.
template <>
struct SetZeroFunctor<CPUDevice, string> {
void operator()(const CPUDevice& d, typename TTypes<string>::Flat out) {
out.device(d) = out.constant(string());
}
};
#define DEFINE_SETZERO_CPU(T) template struct SetZeroFunctor<CPUDevice, T>;
DEFINE_SETZERO_CPU(Eigen::half);
DEFINE_SETZERO_CPU(float);
DEFINE_SETZERO_CPU(double);
DEFINE_SETZERO_CPU(uint8);
DEFINE_SETZERO_CPU(int8);
DEFINE_SETZERO_CPU(uint16);
DEFINE_SETZERO_CPU(int16);
DEFINE_SETZERO_CPU(int32);
DEFINE_SETZERO_CPU(int64);
DEFINE_SETZERO_CPU(complex64);
DEFINE_SETZERO_CPU(complex128);
DEFINE_SETZERO_CPU(string);
#undef DEFINE_SETZERO_CPU
} // end namespace functor
template <typename Device, typename T>

View File

@ -35,7 +35,6 @@ struct scalar_fmod2_op {
return std::fmod(a, b);
}
};
template <typename T>
struct functor_traits<scalar_fmod2_op<T>> {
enum {
@ -44,6 +43,24 @@ struct functor_traits<scalar_fmod2_op<T>> {
};
};
// TODO(rmlarsen): This is a workaround for upstream change
// https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9
// that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary
// version of the latter. Remove once we upgrade to Eigen 3.3.
template <typename Scalar, typename Exponent>
struct scalar_binary_pow_op_google {
EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google)
EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
const Exponent& b) const {
return numext::pow(a, b);
}
};
template <typename Scalar, typename Exponent>
struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
};
template <typename T, typename DivOrMod>
struct safe_div_or_mod_op {
static_assert(std::is_integral<T>::value, "Integer type expected");
@ -477,7 +494,7 @@ struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
};
template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op<T, T> > {};
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};
template <typename T>
struct maximum : base<T, Eigen::internal::scalar_max_op<T> > {};
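The workaround above keeps TensorFlow compiling across an upstream Eigen rename by supplying its own binary pow functor together with the functor_traits specialization that Eigen's expression machinery consults for cost and vectorization decisions. The shape of that pattern, stripped of Eigen so it compiles standalone (all names here are illustrative):

#include <cmath>
#include <cstdio>

// Primary template left undefined, mirroring Eigen's pattern: every functor
// must provide its own specialization.
template <typename Op> struct functor_traits;

template <typename Scalar, typename Exponent>
struct binary_pow_op {
  Scalar operator()(const Scalar& a, const Exponent& b) const {
    return std::pow(a, b);
  }
};

template <typename Scalar, typename Exponent>
struct functor_traits<binary_pow_op<Scalar, Exponent>> {
  enum { Cost = 5, PacketAccess = false };  // pow is costly; no SIMD path
};

int main() {
  binary_pow_op<double, double> op;
  std::printf("%g cost=%d\n", op(2.0, 10.0),
              int(functor_traits<binary_pow_op<double, double>>::Cost));
}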

View File

@ -61,22 +61,23 @@ class DrawBoundingBoxesOp : public OpKernel {
for (int64 b = 0; b < batch_size; ++b) {
const int64 num_boxes = boxes.dim_size(1);
const auto tboxes = boxes.tensor<T, 3>();
for (int64 bb = 0; bb < num_boxes; ++bb) {
const int64 min_box_row =
(height - 1) * boxes.tensor<float, 3>()(b, bb, 0);
static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
const int64 min_box_row_clamp =
std::max<int64>(min_box_row, 0);
const int64 max_box_row =
(height - 1) * boxes.tensor<float, 3>()(b, bb, 2);
static_cast<float>(tboxes(b, bb, 2)) * (height - 1);
const int64 max_box_row_clamp =
std::min<int64>(max_box_row, height - 1);
const int64 min_box_col =
(width - 1) * boxes.tensor<float, 3>()(b, bb, 1);
static_cast<float>(tboxes(b, bb, 1)) * (width - 1);
const int64 min_box_col_clamp =
std::max<int64>(min_box_col, 0);
const int64 max_box_col =
(width - 1) * boxes.tensor<float, 3>()(b, bb, 3);
static_cast<float>(tboxes(b, bb, 3)) * (width - 1);
const int64 max_box_col_clamp =
std::min<int64>(max_box_col, width - 1);
@ -121,22 +122,22 @@ class DrawBoundingBoxesOp : public OpKernel {
// Draw top line.
if (min_box_row >= 0) {
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
canvas(b, min_box_row, j, 0) = T(nanf(""));
canvas(b, min_box_row, j, 0) = Eigen::NumTraits<T>::quiet_NaN();
}
// Draw bottom line.
if (max_box_row < height) {
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
canvas(b, max_box_row, j, 0) = T(nanf(""));
canvas(b, max_box_row, j, 0) = Eigen::NumTraits<T>::quiet_NaN();
}
// Draw left line.
if (min_box_col >= 0) {
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
canvas(b, i, min_box_col, 0) = T(nanf(""));
canvas(b, i, min_box_col, 0) = Eigen::NumTraits<T>::quiet_NaN();
}
// Draw right line.
if (max_box_col < width) {
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
canvas(b, i, max_box_col, 0) = T(nanf(""));
canvas(b, i, max_box_col, 0) = Eigen::NumTraits<T>::quiet_NaN();
}
}
}
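The change swaps T(nanf("")), which manufactures a float NaN and converts it to T, for Eigen::NumTraits<T>::quiet_NaN(), which asks the traits machinery for a NaN already in T's own representation; that distinction matters for narrow types such as Eigen::half. The standard-library analogue of the same idea:

#include <cmath>
#include <iostream>
#include <limits>

template <typename T>
T MarkerValue() {
  // Ask the type itself for its NaN instead of casting a float NaN into it.
  static_assert(std::numeric_limits<T>::has_quiet_NaN, "T needs a NaN");
  return std::numeric_limits<T>::quiet_NaN();
}

int main() {
  float f = MarkerValue<float>();
  double d = MarkerValue<double>();
  std::cout << std::isnan(f) << " " << std::isnan(d) << "\n";  // prints: 1 1
}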

View File

@ -0,0 +1,57 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/fill_functor.h"
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace functor {
template <typename T>
void SetZeroFunctor<Eigen::ThreadPoolDevice, T>::operator()(
const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
out.device(d) = out.constant(T(0));
}
void SetZeroFunctor<Eigen::ThreadPoolDevice, string>::operator()(
const Eigen::ThreadPoolDevice& d, typename TTypes<string>::Flat out) {
out.device(d) = out.constant(string());
}
// Explicit instantiations.
#define DEFINE_SETZERO_CPU(T) \
template struct SetZeroFunctor<Eigen::ThreadPoolDevice, T>;
DEFINE_SETZERO_CPU(bool);
DEFINE_SETZERO_CPU(Eigen::half);
DEFINE_SETZERO_CPU(float);
DEFINE_SETZERO_CPU(double);
DEFINE_SETZERO_CPU(uint8);
DEFINE_SETZERO_CPU(int8);
DEFINE_SETZERO_CPU(uint16);
DEFINE_SETZERO_CPU(int16);
DEFINE_SETZERO_CPU(int32);
DEFINE_SETZERO_CPU(int64);
DEFINE_SETZERO_CPU(complex64);
DEFINE_SETZERO_CPU(complex128);
DEFINE_SETZERO_CPU(string);
#undef DEFINE_SETZERO_CPU
} // namespace functor
} // namespace tensorflow

View File

@ -16,8 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
#define TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace functor {
@ -35,6 +38,19 @@ struct SetZeroFunctor {
void operator()(const Device& d, typename TTypes<T>::Flat out);
};
// Partial specialization of SetZeroFunctor<Device=Eigen::ThreadPoolDevice, T>.
template <typename T>
struct SetZeroFunctor<Eigen::ThreadPoolDevice, T> {
void operator()(const Eigen::ThreadPoolDevice& d,
typename TTypes<T>::Flat out);
};
template <>
struct SetZeroFunctor<Eigen::ThreadPoolDevice, string> {
void operator()(const Eigen::ThreadPoolDevice& d,
typename TTypes<string>::Flat out);
};
} // namespace functor
} // namespace tensorflow

View File

@ -64,19 +64,21 @@ TEST_F(RestoreOpTest, RestoreSimple) {
const std::vector<string> tensor_names = {
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64"};
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
"tensor_half"};
// We first need to write a tensor using the save_op
{
// Initialize an operation
NodeDef save;
TF_ASSERT_OK(NodeDefBuilder("myop", "Save")
.Input(FakeInput())
.Input(FakeInput())
.Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE,
DT_QINT8, DT_QINT32, DT_UINT8, DT_INT8,
DT_INT16, DT_STRING, DT_COMPLEX64}))
.Finalize(&save));
TF_ASSERT_OK(
NodeDefBuilder("myop", "Save")
.Input(FakeInput())
.Input(FakeInput())
.Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8,
DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_STRING,
DT_COMPLEX64, DT_HALF}))
.Finalize(&save));
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
@ -156,7 +158,12 @@ TEST_F(RestoreOpTest, RestoreSimple) {
TensorShape({2, 3}),
[](int x) -> complex64 { return complex64(100 + x, 200 + x); });
inputs.push_back({nullptr, &input_13});
// Input #14 is a 2-d half tensor
Tensor input_14 =
MakeInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
return static_cast<Eigen::half>(x) / Eigen::half(5);
});
inputs.push_back({nullptr, &input_14});
OpKernelContext::Params params;
params.device = device.get();
params.frame_iter = FrameAndIter(0, 0);
@ -321,6 +328,19 @@ TEST_F(RestoreOpTest, RestoreSimple) {
EXPECT_EQ(complex64(100 + i, 200 + i), output->flat<complex64>()(i));
}
}
// The 2-d half tensor
{
MakeRestoreOp(DT_HALF);
(*mutable_input(1).tensor).scalar<string>()() = tensor_names[12];
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
TensorShape expected({2, 4});
EXPECT_TRUE(output->shape().IsSameSize(expected));
for (int i = 0; i < 8; ++i) {
EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(5),
output->flat<Eigen::half>()(i));
}
}
}
class RestoreSliceOpTest : public OpsTestBase {

View File

@ -44,7 +44,7 @@ class SaveOpTest : public OpsTestBase {
.Input(FakeInput())
.Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8,
DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_INT64,
DT_STRING, DT_COMPLEX64, DT_COMPLEX128}))
DT_STRING, DT_COMPLEX64, DT_COMPLEX128, DT_HALF}))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
@ -53,10 +53,10 @@ class SaveOpTest : public OpsTestBase {
TEST_F(SaveOpTest, Simple) {
const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
const string tensornames[] = {
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
"tensor_complex128"};
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
"tensor_complex128", "tensor_half"};
MakeOp();
// Add a file name
@ -64,7 +64,7 @@ TEST_F(SaveOpTest, Simple) {
[&filename](int x) -> string { return filename; });
// Add the tensor names
AddInput<string>(TensorShape({13}),
AddInput<string>(TensorShape({14}),
[&tensornames](int x) -> string { return tensornames[x]; });
// Add a 1-d bool tensor
@ -116,6 +116,10 @@ TEST_F(SaveOpTest, Simple) {
return complex128(100 + x, 200 + x);
});
// Add a 2-d half tensor
AddInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
return static_cast<Eigen::half>(x) / Eigen::half(2);
});
TF_ASSERT_OK(RunOpKernel());
// Check that the checkpoint file is properly written
@ -363,6 +367,24 @@ TEST_F(SaveOpTest, Simple) {
}
}
{
// The 2-d half tensor
TensorShape shape;
DataType type;
EXPECT_TRUE(reader.HasTensor("tensor_half", &shape, &type));
TensorShape expected({2, 4});
EXPECT_TRUE(shape.IsSameSize(expected));
EXPECT_EQ(DT_HALF, type);
// We expect the tensor value to be correct.
TensorSlice s = TensorSlice::ParseOrDie("-:-");
Eigen::half data[8];
std::fill_n(data, 8, Eigen::half(0));
EXPECT_TRUE(reader.CopySliceData("tensor_half", s, data));
for (int i = 0; i < 8; ++i) {
EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(2), data[i]);
}
}
}
class SaveSlicesOpTest : public OpsTestBase {

View File

@ -32,10 +32,11 @@ static string GetRendezvousKeyPrefix(const string& send_device,
recv_device, ";", tensor_name);
}
static string GetRendezvousKey(const string& key_prefix,
const FrameAndIter& frame_iter) {
return strings::StrCat(key_prefix, ";", frame_iter.frame_id, ":",
frame_iter.iter_id);
static void GetRendezvousKey(const string& key_prefix,
const FrameAndIter& frame_iter, string* key) {
key->clear();
strings::StrAppend(key, key_prefix, ";", frame_iter.frame_id, ":",
frame_iter.iter_id);
}
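Switching from a by-value return to an out-parameter lets the caller reuse one string's capacity across invocations: key->clear() keeps the allocation, and StrAppend writes into it rather than building a fresh string each time. The pattern in isolation (plain std::string appends in place of strings::StrAppend):

#include <iostream>
#include <string>

// Out-parameter form: clear() + append() reuses "key"'s existing capacity.
static void MakeKey(const std::string& prefix, int frame, int iter,
                    std::string* key) {
  key->clear();
  key->append(prefix);
  key->append(";");
  key->append(std::to_string(frame));
  key->append(":");
  key->append(std::to_string(iter));
}

int main() {
  std::string key;  // allocated once...
  for (int i = 0; i < 3; ++i) {
    MakeKey("/cpu:0;/cpu:1;edge", 0, i, &key);  // ...reused every iteration
    std::cout << key << "\n";
  }
}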
SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -57,9 +58,13 @@ void SendOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES(
ctx, ctx->rendezvous() != nullptr,
errors::Internal("Op kernel context needs to provide a rendezvous."));
const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
string key;
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key);
VLOG(2) << "Send " << key;
Rendezvous::ParsedKey parsed;
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(key, &parsed));
// The device context may be passed between the Send/Recv
// boundary, so that the device context used to produce the Tensor
// is used when performing the copy on the recv side (which may be
@ -67,9 +72,8 @@ void SendOp::Compute(OpKernelContext* ctx) {
Rendezvous::Args args;
args.device_context = ctx->op_device_context();
args.alloc_attrs = ctx->input_alloc_attr(0);
Status s =
ctx->rendezvous()->Send(key, args, ctx->input(0), ctx->is_input_dead());
ctx->SetStatus(s);
OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed, args, ctx->input(0),
ctx->is_input_dead()));
}
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
@ -98,16 +102,21 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
OP_REQUIRES(
ctx, ctx->rendezvous() != nullptr,
errors::Internal("Op kernel context needs to provide a rendezvous."));
const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
string key;
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key);
VLOG(2) << "Recv " << key;
Rendezvous::ParsedKey parsed;
OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(key, &parsed), done);
Rendezvous::Args args;
args.device_context = ctx->op_device_context();
args.alloc_attrs = ctx->output_alloc_attr(0);
ctx->rendezvous()->RecvAsync(
key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& val, bool is_dead) {
parsed, args,
[ctx, done](const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& val,
bool is_dead) {
ctx->SetStatus(s);
if (s.ok()) {
// 'ctx' allocates the output tensor of the expected type. The

View File

@ -74,15 +74,14 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
Impl(Env* env, const ThreadOptions& thread_options, const string& name,
int num_threads)
: Eigen::ThreadPoolTempl<EigenEnvironment>(
num_threads, EigenEnvironment(env, thread_options, name)),
num_threads_(num_threads) {}
num_threads, EigenEnvironment(env, thread_options, name)) {}
void ParallelFor(int64 total, int64 cost_per_unit,
std::function<void(int64, int64)> fn) {
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
CHECK_GE(total, 0);
CHECK_EQ(total, (int64)(Eigen::Index)total);
Eigen::ThreadPoolDevice device(this, num_threads_);
Eigen::ThreadPoolDevice device(this, this->NumThreads());
device.parallelFor(
total, Eigen::TensorOpCost(0, 0, cost_per_unit),
[&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
@ -90,10 +89,6 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
CHECK(0); // should not be used with the old thread pool
#endif
}
int NumThreads() const { return num_threads_; };
const int num_threads_;
};
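With `num_threads_` gone, the shard count now comes from the base class via `this->NumThreads()`. A loose Python analogue of the `ParallelFor` contract, sharding `[0, total)` into contiguous ranges (the names and shard policy here are illustrative, not TensorFlow's):

```
from concurrent.futures import ThreadPoolExecutor

def parallel_for(total, fn, num_threads=4):
    # Split [0, total) into contiguous shards and call fn(first, last)
    # on each, mirroring the callback shape used above.
    step = max(1, total // num_threads)
    shards = [(i, min(i + step, total)) for i in range(0, total, step)]
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        list(pool.map(lambda shard: fn(*shard), shards))

parallel_for(10, lambda first, last: print(first, last))
```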
ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
@ -120,5 +115,7 @@ void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); }
} // namespace thread
} // namespace tensorflow

View File

@ -57,6 +57,10 @@ class ThreadPool {
// Returns the number of threads in the pool.
int NumThreads() const;
// Returns current thread id between 0 and NumThreads() - 1, if called from a
// thread in the pool. Returns -1 otherwise.
int CurrentThreadId() const;
struct Impl;
private:

View File

@ -1013,7 +1013,7 @@ Computes softmax activations.
For each batch `i` and class `j` we have
softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))
softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
logits: 2-D with shape `[batch_size, num_classes]`.
softmax: Same shape as `logits`.
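A quick NumPy check of the corrected formula, which normalizes each row by its own sum (illustrative, not part of the change):

```
import numpy as np

logits = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
softmax = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
assert np.allclose(softmax.sum(axis=1), 1.0)  # every row sums to 1
```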

View File

@ -11755,7 +11755,7 @@ op {
}
}
summary: "Computes softmax activations."
description: "For each batch `i` and class `j` we have\n\n softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))"
description: "For each batch `i` and class `j` we have\n\n softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))"
}
op {
name: "SoftmaxCrossEntropyWithLogits"

View File

@ -101,7 +101,7 @@ def tf_additional_stream_executor_srcs():
return ["platform/default/stream_executor.h"]
def tf_additional_cupti_wrapper_deps():
return [":cupti_wrapper_default"]
return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]
def tf_additional_test_deps():
return []

View File

@ -0,0 +1,23 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
"tf_cuda_library",
)
tf_cuda_library(
name = "cupti_wrapper",
srcs = [
"cupti_wrapper.cc",
],
hdrs = [
"cupti_wrapper.h",
],
copts = tf_copts(),
cuda_deps = [
"//tensorflow/core:stream_executor",
"//third_party/gpus/cuda:cuda_headers",
"//third_party/gpus/cuda:cupti_headers",
],
data = ["//third_party/gpus/cuda:cupti_dsos"],
visibility = ["//visibility:public"],
)

View File

@ -104,22 +104,29 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
}
StringPiece tmp;
while (!fullname.empty()) {
bool progress = false;
if (str_util::ConsumePrefix(&fullname, "/job:")) {
p->has_job = !str_util::ConsumePrefix(&fullname, "*");
if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
return false;
}
} else if (str_util::ConsumePrefix(&fullname, "/replica:")) {
progress = true;
}
if (str_util::ConsumePrefix(&fullname, "/replica:")) {
p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
return false;
}
} else if (str_util::ConsumePrefix(&fullname, "/task:")) {
progress = true;
}
if (str_util::ConsumePrefix(&fullname, "/task:")) {
p->has_task = !str_util::ConsumePrefix(&fullname, "*");
if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
return false;
}
} else if (str_util::ConsumePrefix(&fullname, "/device:")) {
progress = true;
}
if (str_util::ConsumePrefix(&fullname, "/device:")) {
p->has_type = !str_util::ConsumePrefix(&fullname, "*");
if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
return false;
@ -132,24 +139,31 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
return false;
}
}
progress = true;
}
} else if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
str_util::ConsumePrefix(&fullname, "/CPU:")) {
if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
str_util::ConsumePrefix(&fullname, "/CPU:")) {
p->has_type = true;
p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...'
p->has_id = !str_util::ConsumePrefix(&fullname, "*");
if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
return false;
}
} else if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
str_util::ConsumePrefix(&fullname, "/GPU:")) {
progress = true;
}
if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
str_util::ConsumePrefix(&fullname, "/GPU:")) {
p->has_type = true;
p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...'
p->has_id = !str_util::ConsumePrefix(&fullname, "*");
if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
return false;
}
} else {
progress = true;
}
if (!progress) {
return false;
}
}
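The restructured loop tries each component on every pass and fails only when no prefix matched, so the components can be recognized independently. A loose Python analogue for the fully specified form (the real parser also accepts `*` wildcards and the legacy `/cpu:N` and `/gpu:N` spellings; the regex below is an illustration):

```
import re

_FULL_NAME_RE = re.compile(
    r"/job:(?P<job>[^/]+)/replica:(?P<replica>\d+)"
    r"/task:(?P<task>\d+)/device:(?P<type>[A-Za-z_]+):(?P<id>\d+)$")

def parse_full_name(name):
    m = _FULL_NAME_RE.match(name)
    return m.groupdict() if m else None

print(parse_full_name("/job:worker/replica:0/task:1/device:GPU:0"))
```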
@ -340,11 +354,22 @@ bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
string* device) {
ParsedName pn;
if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
*task = strings::StrCat(
(pn.has_job ? strings::StrCat("/job:", pn.job) : ""),
(pn.has_replica ? strings::StrCat("/replica:", pn.replica) : ""),
(pn.has_task ? strings::StrCat("/task:", pn.task) : ""));
*device = strings::StrCat(pn.type, ":", pn.id);
task->clear();
task->reserve(
(pn.has_job ? (5 + pn.job.size()) : 0) +
(pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
(pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
if (pn.has_job) {
strings::StrAppend(task, "/job:", pn.job);
}
if (pn.has_replica) {
strings::StrAppend(task, "/replica:", pn.replica);
}
if (pn.has_task) {
strings::StrAppend(task, "/task:", pn.task);
}
device->clear();
strings::StrAppend(device, pn.type, ":", pn.id);
return true;
}
return false;
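In the same illustrative terms, the split the rewritten code produces, given the dict from the sketch above:

```
def split_device_name(parsed):
    # task = "/job:j/replica:r/task:t", device = "TYPE:id", matching the
    # two outputs of SplitDeviceName.
    task = "/job:%(job)s/replica:%(replica)s/task:%(task)s" % parsed
    return task, "%(type)s:%(id)s" % parsed

print(split_device_name({"job": "worker", "replica": "0", "task": "1",
                         "type": "GPU", "id": "0"}))
```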

View File

@ -87,6 +87,44 @@ struct CopyThatWorksWithStringPointer<string> {
}
};
// Checkpointing of half is done by storing the raw 16 bits as a signed 32bit
// integer. To restore the checkpoint we need to do the reverse operation by
// reinterpreting the integer as a 16 bit float. This prevents us from using
// the default cast operation.
template <>
struct CopyThatWorksWithStringPointer<Eigen::half> {
template <typename SrcTensor, typename DstTensor, typename Shape>
static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d,
Shape d_start) {
typedef typename SrcTensor::Index Index;
static_assert(kTensorSliceMaxRank == 8,
"If kTensorSliceMaxRank changes, modify the loop below.");
for (Index i0 = 0; i0 < len[0]; i0++) {
for (Index i1 = 0; i1 < len[1]; i1++) {
for (Index i2 = 0; i2 < len[2]; i2++) {
for (Index i3 = 0; i3 < len[3]; i3++) {
for (Index i4 = 0; i4 < len[4]; i4++) {
for (Index i5 = 0; i5 < len[5]; i5++) {
for (Index i6 = 0; i6 < len[6]; i6++) {
for (Index i7 = 0; i7 < len[7]; i7++) {
d(d_start[0] + i0, d_start[1] + i1, d_start[2] + i2,
d_start[3] + i3, d_start[4] + i4, d_start[5] + i5,
d_start[6] + i6, d_start[7] + i7) =
Eigen::internal::raw_uint16_to_half(
s(s_start[0] + i0, s_start[1] + i1, s_start[2] + i2,
s_start[3] + i3, s_start[4] + i4, s_start[5] + i5,
s_start[6] + i6, s_start[7] + i7));
}
}
}
}
}
}
}
}
}
};
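The comment above is the key point: a half is checkpointed as its raw 16 bits widened to a 32-bit integer, and restore reinterprets those bits rather than casting numerically. A tiny NumPy illustration of that bit-level round trip (not TensorFlow code):

```
import numpy as np

original = np.array([1.5], dtype=np.float16)
stored = original.view(np.uint16).astype(np.int32)    # raw bits, widened
restored = stored.astype(np.uint16).view(np.float16)  # reinterpret, no cast
assert (restored == original).all()
```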
// Given a tensor described by "shape", two slices "slice_s" and "slice_d",
// and two pointers "ptr_s" and "ptr_d", where "ptr_s" points to a chunk of
// memory that stores the data for "slice_s" and "ptr_d" points to a chunk of

View File

@ -30,9 +30,9 @@ CORPUS_FILENAME = "europarl-v6.fr-en.en"
MAX_DOC_LENGTH = 10
def training_data(filename):
f = open(filename)
for line in f:
yield line
f = open(filename)
for line in f:
yield line
def iter_docs(docs):
@ -67,34 +67,34 @@ HIDDEN_SIZE = 10
def seq_autoencoder(X, y):
"""Sequence auto-encoder with RNN."""
inputs = learn.ops.one_hot_matrix(X, 256)
in_X, in_y, out_y = learn.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH)
encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell)
return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding)
"""Sequence auto-encoder with RNN."""
inputs = learn.ops.one_hot_matrix(X, 256)
in_X, in_y, out_y = learn.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH)
encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell)
return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding)
def get_language_model(hidden_size):
"""Returns a language model with given hidden size."""
"""Returns a language model with given hidden size."""
def language_model(X, y):
inputs = learn.ops.one_hot_matrix(X, 256)
inputs = learn.ops.split_squeeze(1, MAX_DOC_LENGTH, inputs)
target = learn.ops.split_squeeze(1, MAX_DOC_LENGTH, y)
encoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(hidden_size),256)
output, _ = tf.nn.rnn(encoder_cell, inputs, dtype=tf.float32)
return learn.ops.sequence_classifier(output, target)
return language_model
def language_model(X, y):
inputs = learn.ops.one_hot_matrix(X, 256)
inputs = tf.unpack(inputs, axis=1)
target = tf.unpack(y, axis=1)
encoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(hidden_size),256)
output, _ = tf.nn.rnn(encoder_cell, inputs, dtype=tf.float32)
return learn.ops.sequence_classifier(output, target)
return language_model
### Training model.
estimator = learn.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE),
n_classes=256,
optimizer='Adam', learning_rate=0.01,
steps=1000, batch_size=64, continue_training=True)
estimator = learn.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE),
n_classes=256, optimizer='Adam',
learning_rate=0.01, steps=1000,
batch_size=64, continue_training=True)
estimator.fit(X, y)

View File

@ -56,7 +56,7 @@ def rnn_model(x, y):
# Split into list of embedding per word, while removing doc length dim.
# word_list will be a list of tensors [batch_size, EMBEDDING_SIZE].
word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
word_list = tf.unpack(word_vectors, axis=1)
# Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
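Each of these example updates swaps `learn.ops.split_squeeze` for `tf.unpack` with an `axis` argument. A sketch of the shape transformation involved (the placeholder shape is illustrative):

```
import tensorflow as tf

# [batch, time, depth] -> list of `time` tensors of shape [batch, depth],
# the same result split_squeeze(1, time, x) used to produce.
x = tf.placeholder(tf.float32, [32, 10, 50])
word_list = tf.unpack(x, axis=1)
print(len(word_list), word_list[0].get_shape())  # 10 (32, 50)
```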

View File

@ -41,7 +41,7 @@ def input_op_fn(x):
embedding_size=EMBEDDING_SIZE, name='words')
# Split into list of embedding per word, while removing doc length dim.
# word_list will be a list of tensors [batch_size, EMBEDDING_SIZE].
word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
word_list = tf.unpack(word_vectors, axis=1)
return word_list

View File

@ -47,7 +47,7 @@ def char_rnn_model(x, y):
"""Character level recurrent neural network model to predict classes."""
y = tf.one_hot(y, 15, 1, 0)
byte_list = learn.ops.one_hot_matrix(x, 256)
byte_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, byte_list)
byte_list = tf.unpack(byte_list, axis=1)
cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
_, encoding = tf.nn.rnn(cell, byte_list, dtype=tf.float32)

View File

@ -46,47 +46,47 @@ print('Total words: %d' % n_words)
EMBEDDING_SIZE = 50
def average_model(X, y):
word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
features = tf.reduce_max(word_vectors, reduction_indices=1)
return learn.models.logistic_regression(features, y)
word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
features = tf.reduce_max(word_vectors, reduction_indices=1)
return learn.models.logistic_regression(features, y)
def rnn_model(X, y):
"""Recurrent neural network model to predict from sequence of words
"""Recurrent neural network model to predict from sequence of words
to a class."""
# Convert indexes of words into embeddings.
# This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then
# maps word indexes of the sequence into [batch_size, sequence_length,
# EMBEDDING_SIZE].
word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
# Split into list of embedding per word, while removing doc length dim.
# word_list will be a list of tensors [batch_size, EMBEDDING_SIZE].
word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
# Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
# Create an unrolled Recurrent Neural Networks to length of
# MAX_DOCUMENT_LENGTH and passes word_list as inputs for each unit.
_, encoding = tf.nn.rnn(cell, word_list, dtype=tf.float32)
# Given encoding of RNN, take encoding of last step (e.g. hidden size of the
# neural network of last step) and pass it as features for logistic
# regression over output classes.
return learn.models.logistic_regression(encoding, y)
# Convert indexes of words into embeddings.
# This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then
# maps word indexes of the sequence into [batch_size, sequence_length,
# EMBEDDING_SIZE].
word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
# Split into list of embedding per word, while removing doc length dim.
# word_list will be a list of tensors [batch_size, EMBEDDING_SIZE].
word_list = tf.unpack(word_vectors, axis=1)
# Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
# Create an unrolled Recurrent Neural Networks to length of
# MAX_DOCUMENT_LENGTH and passes word_list as inputs for each unit.
_, encoding = tf.nn.rnn(cell, word_list, dtype=tf.float32)
# Given encoding of RNN, take encoding of last step (e.g. hidden size of the
# neural network of last step) and pass it as features for logistic
# regression over output classes.
return learn.models.logistic_regression(encoding, y)
model_path = '/tmp/skflow_examples/text_classification'
if os.path.exists(model_path):
classifier = learn.TensorFlowEstimator.restore(model_path)
classifier = learn.TensorFlowEstimator.restore(model_path)
else:
classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)
classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)
# Continuously train for 1000 steps
while True:
try:
classifier.fit(X_train, y_train)
except KeyboardInterrupt:
classifier.save(model_path)
break
# Continuously train for 1000 steps
while True:
try:
classifier.fit(X_train, y_train)
except KeyboardInterrupt:
classifier.save(model_path)
break
# Predict on test set
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
print('Accuracy: {0:f}'.format(score))

View File

@ -170,6 +170,8 @@ def train():
else: # Record a summary
summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
train_writer.add_summary(summary, i)
train_writer.close()
test_writer.close()
def main(_):

View File

@ -3199,7 +3199,7 @@ dist.pmf(counts) # Shape [2]
```
- - -
#### `tf.contrib.distributions.DirichletMultinomial.__init__(n, alpha, allow_arbitrary_counts=False, strict=True, name='DirichletMultinomial')` {#DirichletMultinomial.__init__}
#### `tf.contrib.distributions.DirichletMultinomial.__init__(n, alpha, allow_arbitrary_counts=False, allow_nan=False, strict=True, name='DirichletMultinomial')` {#DirichletMultinomial.__init__}
Initialize a batch of DirichletMultinomial distributions.
@ -3216,8 +3216,15 @@ Initialize a batch of DirichletMultinomial distributions.
* <b>`allow_arbitrary_counts`</b>: Boolean. This represents whether the pmf/cdf
allows for the `counts` tensor to be non-integral values.
The pmf/cdf are functions that can be evaluated at non-integral values,
but are only a distribution over non-negative integers.
* <b>`strict`</b>: Not used (yet).
but are only a distribution over non-negative integers. If `strict` is
`False`, this assertion is turned off.
* <b>`allow_nan`</b>: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
* <b>`strict`</b>: Whether to assert valid values for parameters `alpha` and `n`, and
`x` in `pmf` and `log_pmf`. If False, correct behavior is not
guaranteed.
* <b>`name`</b>: The name to prefix Ops created by this distribution class.
@ -3233,6 +3240,13 @@ dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
```
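A hypothetical construction exercising both documented flags (the argument values are made up):

```
import tensorflow as tf

distributions = tf.contrib.distributions
# allow_nan=True: undefined statistics return NaN instead of raising.
# strict=False: skip runtime validity checks on `n`, `alpha`, and counts.
dist = distributions.DirichletMultinomial(
    n=[3., 4.], alpha=[[1., 2., 3.], [4., 5., 6.]],
    allow_nan=True, strict=False)
```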
- - -
#### `tf.contrib.distributions.DirichletMultinomial.allow_nan` {#DirichletMultinomial.allow_nan}
Boolean describing behavior when a stat is undefined for batch member.
- - -
#### `tf.contrib.distributions.DirichletMultinomial.alpha` {#DirichletMultinomial.alpha}

View File

@ -40,15 +40,24 @@ Initializes a BaseEstimator instance.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
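A hypothetical call matching this contract; `my_estimator` and `my_eval_input_fn` stand in for objects built elsewhere and are not part of the diff:

```
# With input_fn set, x, y, and batch_size must stay None. The returned
# dict holds the requested metrics plus the 'global_step' at which this
# evaluation ran.
metrics = my_estimator.evaluate(input_fn=my_eval_input_fn, steps=100)
print(metrics["global_step"])
```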
@ -318,15 +327,24 @@ Constructs an Estimator instance.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -599,15 +617,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -938,15 +965,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -1320,15 +1356,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -1603,15 +1648,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -1859,15 +1913,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -2491,15 +2554,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -2856,15 +2928,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -3139,15 +3220,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -3395,15 +3485,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
@ -4283,15 +4382,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -97,15 +97,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -115,15 +115,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -1,4 +1,4 @@
#### `tf.train.Server.create_local_server(start=True)` {#Server.create_local_server}
#### `tf.train.Server.create_local_server(config=None, start=True)` {#Server.create_local_server}
Creates a new single-process cluster running on the local host.
@ -10,6 +10,8 @@ single-process cluster containing a single task in a job called
##### Args:
* <b>`config`</b>: (Optional.) A `tf.ConfigProto` that specifies default
configuration options for all sessions that run on this server.
* <b>`start`</b>: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to `True`.

View File

@ -25,15 +25,24 @@ Initializes a BaseEstimator instance.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -33,15 +33,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -67,7 +67,7 @@ dist.pmf(counts) # Shape [2]
```
- - -
#### `tf.contrib.distributions.DirichletMultinomial.__init__(n, alpha, allow_arbitrary_counts=False, strict=True, name='DirichletMultinomial')` {#DirichletMultinomial.__init__}
#### `tf.contrib.distributions.DirichletMultinomial.__init__(n, alpha, allow_arbitrary_counts=False, allow_nan=False, strict=True, name='DirichletMultinomial')` {#DirichletMultinomial.__init__}
Initialize a batch of DirichletMultinomial distributions.
@ -84,8 +84,15 @@ Initialize a batch of DirichletMultinomial distributions.
* <b>`allow_arbitrary_counts`</b>: Boolean. This represents whether the pmf/cdf
allows for the `counts` tensor to be non-integral values.
The pmf/cdf are functions that can be evaluated at non-integral values,
but are only a distribution over non-negative integers.
* <b>`strict`</b>: Not used (yet).
but are only a distribution over non-negative integers. If `strict` is
`False`, this assertion is turned off.
* <b>`allow_nan`</b>: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
* <b>`strict`</b>: Whether to assert valid values for parameters `alpha` and `n`, and
`x` in `pmf` and `log_pmf`. If False, correct behavior is not
guaranteed.
* <b>`name`</b>: The name to prefix Ops created by this distribution class.
@ -101,6 +108,13 @@ dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
```
- - -
#### `tf.contrib.distributions.DirichletMultinomial.allow_nan` {#DirichletMultinomial.allow_nan}
Boolean describing behavior when a stat is undefined for batch member.
- - -
#### `tf.contrib.distributions.DirichletMultinomial.alpha` {#DirichletMultinomial.alpha}

View File

@ -42,15 +42,24 @@ Constructs an Estimator instance.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -9,7 +9,7 @@ communicate with any other server in the same cluster.
- - -
#### `tf.train.Server.__init__(server_or_cluster_def, job_name=None, task_index=None, protocol=None, start=True)` {#Server.__init__}
#### `tf.train.Server.__init__(server_or_cluster_def, job_name=None, task_index=None, protocol=None, config=None, start=True)` {#Server.__init__}
Creates a new server with the given definition.
@ -32,6 +32,8 @@ override any information provided in `server_or_cluster_def`.
* <b>`protocol`</b>: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc"`. Defaults to the value in
`server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
* <b>`config`</b>: (Optional.) A `tf.ConfigProto` that specifies default
configuration options for all sessions that run on this server.
* <b>`start`</b>: (Optional.) Boolean, indicating whether to start the server
after creating it. Defaults to `True`.
@ -43,7 +45,7 @@ override any information provided in `server_or_cluster_def`.
- - -
#### `tf.train.Server.create_local_server(start=True)` {#Server.create_local_server}
#### `tf.train.Server.create_local_server(config=None, start=True)` {#Server.create_local_server}
Creates a new single-process cluster running on the local host.
@ -55,6 +57,8 @@ single-process cluster containing a single task in a job called
##### Args:
* <b>`config`</b>: (Optional.) A `tf.ConfigProto` that specifies default
configuration options for all sessions that run on this server.
* <b>`start`</b>: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to `True`.
@ -84,6 +88,18 @@ with tf.Session(server.target):
A string containing a session target for this server.
- - -
#### `tf.train.Server.server_def` {#Server.server_def}
Returns the `tf.train.ServerDef` for this server.
##### Returns:
A `tf.train.ServerDef` protocol buffer that describes the configuration
of this server.
- - -
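A hypothetical use of the new `config` argument; sessions that run on the in-process server pick up these options by default:

```
import tensorflow as tf

config = tf.ConfigProto(log_device_placement=True)
server = tf.train.Server.create_local_server(config=config)
with tf.Session(server.target) as sess:
    print(sess.run(tf.constant(1.0)))  # placement logging comes from config
```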

View File

@ -116,15 +116,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -33,15 +33,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -33,15 +33,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -4,7 +4,7 @@ Computes softmax activations.
For each batch `i` and class `j` we have
softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))
softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
##### Args:

View File

@ -33,15 +33,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -33,15 +33,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -114,15 +114,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

View File

@ -33,15 +33,24 @@ Returns weights of deep neural network part.
Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the evaluation data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
##### Args:
* <b>`x`</b>: features.
* <b>`y`</b>: targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`. If `steps` is `None`, the tensors returned by this should
generally raise an end-of-input exception when all eval records have
been returned (typically, 1 epoch over eval data).
`None`.
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
once per iteration.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first

Some files were not shown because too many files have changed in this diff.