Merge pull request #3082 from rmlarsen/branch_126082003
Branch 126082003
This commit is contained in:
commit
cee7cdd23d
@ -1,7 +1,6 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
archive_dir = "eigen-eigen-802d984ade26"
|
||||
|
||||
archive_dir = "eigen-eigen-334b1d428283"
|
||||
cc_library(
|
||||
name = "eigen",
|
||||
hdrs = glob([archive_dir+"/**/*.h", archive_dir+"/unsupported/Eigen/CXX11/*", archive_dir+"/Eigen/*"]),
|
||||
|
@ -8,7 +8,7 @@ exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_library(
|
||||
name = "bayesflow_py",
|
||||
@ -16,7 +16,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
cuda_py_test(
|
||||
name = "stochastic_graph_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/stochastic_graph_test.py"],
|
||||
@ -27,6 +27,17 @@ cuda_py_tests(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "reinforce_simple_example",
|
||||
size = "small",
|
||||
srcs = ["examples/reinforce_simple/reinforce_simple_example.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -0,0 +1,143 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Simple examples of the REINFORCE algorithm."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
distributions = tf.contrib.distributions
|
||||
sg = tf.contrib.bayesflow.stochastic_graph
|
||||
|
||||
|
||||
def split_apply_merge(inp, partitions, fns):
  """Route slices of `inp` through `fns` and stitch the results back in order.

  Args:
    inp: the input vector
    partitions: tensor of same length as input vector, having values 0, 1
    fns: the two functions.

  Returns:
    the vector routed, where routed[i] = fns[partitions[i]](inp[i])
  """
  num_routes = len(fns)

  # Split: one chunk of `inp` per function, chosen by `partitions`.
  routed_inputs = tf.dynamic_partition(inp, partitions, num_routes)

  # Apply: the k-th function handles the k-th chunk.
  routed_outputs = [fn(chunk) for fn, chunk in zip(fns, routed_inputs)]

  # Merge: partition the original row positions the same way so that
  # dynamic_stitch can put every output back at its source position.
  positions = tf.range(0, inp.get_shape()[0])
  routed_positions = tf.dynamic_partition(positions, partitions, num_routes)
  return tf.dynamic_stitch(routed_positions, routed_outputs)
|
||||
|
||||
|
||||
def plus_1(inputs):
  """Return `inputs` increased by 1.0."""
  shifted = inputs + 1.0
  return shifted
|
||||
|
||||
|
||||
def minus_1(inputs):
  """Return `inputs` decreased by 1.0."""
  shifted = inputs - 1.0
  return shifted
|
||||
|
||||
|
||||
def build_split_apply_merge_model():
  """Build the Split-Apply-Merge Model.

  Each value of the input [-1, -1, 1, 1] is sent through one of the two
  functions plus_1 / minus_1.  Which function is used for each value is
  sampled from a stochastic tensor whose logits come from a small linear
  layer applied to the input.  REINFORCE is used to update the parameters
  of that layer.

  Returns:
    The 3-tuple (route_selection, routing_loss, final_loss), where:

      - route_selection is an int 4-vector
      - routing_loss is a float 4-vector
      - final_loss is a float scalar.
  """
  inputs = tf.constant([[-1.0], [-1.0], [1.0], [1.0]])
  targets = tf.constant([[0.0], [0.0], [0.0], [0.0]])
  routes = [plus_1, minus_1]

  # Linear layer producing two logits per input row.
  weights = tf.get_variable("w", [1, 2])
  bias = tf.get_variable("b", [1, 1])
  logits = tf.matmul(inputs, weights) + bias

  # REINFORCE forward step
  route_selection = sg.DistributionTensor(
      distributions.Categorical, logits=logits)

  # Using route_selection as a Tensor below forces a sample of the
  # Categorical distribution based on its logits; this is equivalent to
  # calling route_selection.value(), which returns an int32 4-vector with
  # random values in {0, 1}.
  # COPY+ROUTE+PASTE
  outputs = split_apply_merge(inputs, route_selection, routes)

  # Flatten the per-example squared error from a column vector to a row
  # vector.
  squared_error = tf.square(outputs - targets)
  routing_loss = tf.reshape(squared_error, shape=[-1])

  # REINFORCE loss.  surrogate_losses returns
  #  [stop_gradient(routing_loss) *
  #   route_selection.log_pmf(stop_gradients(route_selection.value()))],
  # where log_pmf has gradients going all the way back to weights and bias.
  score_function_losses = sg.surrogate_losses([routing_loss])

  # The entire loss is routing_loss plus the score function loss.  Here
  # routing_loss depends on the variables only through "route_selection",
  # which has a stop_gradients on it, so the gradient of the loss really
  # comes through the score function.
  loss_terms = score_function_losses + [routing_loss]
  final_loss = tf.reduce_sum(tf.add_n(loss_terms))

  return (route_selection, routing_loss, final_loss)
|
||||
|
||||
|
||||
class REINFORCESimpleExample(tf.test.TestCase):
  """Trains the split-apply-merge toy model and checks where it ends up."""

  def testSplitApplyMerge(self):
    # Fix the graph-level seed for repeatability: SGD has a tendency to jump
    # around, even on this toy problem.
    tf.set_random_seed(1)

    with self.test_session() as sess:
      # Use sampling to train REINFORCE
      with sg.value_type(sg.SampleAndReshapeValue(n=1)):
        model = build_split_apply_merge_model()
      route_selection, routing_loss, final_loss = model

      optimizer = tf.train.GradientDescentOptimizer(1.0)
      train_op = optimizer.minimize(final_loss)

      tf.initialize_all_variables().run()

      for step in range(10):
        # Run loss and inference step.  This toy problem converges VERY
        # quickly.
        fetches = [
            routing_loss, final_loss, tf.identity(route_selection), train_op]
        routing_loss_v, final_loss_v, route_selection_v, _ = sess.run(fetches)
        print(
            "Iteration %d, routing loss: %s, final_loss: %s, "
            "route selection: %s"
            % (step, routing_loss_v, final_loss_v, route_selection_v))

      # The values from the last iteration must be the zero-loss routing.
      self.assertAllEqual([0, 0, 1, 1], route_selection_v)
      self.assertAllClose([0.0, 0.0, 0.0, 0.0], routing_loss_v)
      self.assertAllClose(0.0, final_loss_v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -265,6 +265,16 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
||||
self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
|
||||
self.assertEqual((), pmf_same.get_shape())
|
||||
|
||||
def testNonStrictTurnsOffAllChecks(self):
  """With strict=False, building and evaluating the pmf must not raise."""
  # Every argument below is deliberately invalid.
  with self.test_session():
    bad_alpha = [[-1., 2]]  # alpha should be positive.
    bad_counts = [[1., 0], [0., -1]]  # counts should be non-negative.
    # n should be a non-negative integer equal to counts.sum.
    bad_n = [-5.3]
    dist = tf.contrib.distributions.DirichletMultinomial(
        bad_n, bad_alpha, strict=False)
    pmf = dist.pmf(bad_counts)
    pmf.eval()  # Should not raise.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
@ -38,21 +38,6 @@ def _assert_integer_form(x):
|
||||
math_ops.round(casted_x), x.dtype))
|
||||
|
||||
|
||||
def _check_alpha(alpha):
  """Convert alpha to a tensor guarded by rank and positivity assertions."""
  alpha_t = ops.convert_to_tensor(alpha, name='alpha_before_deps')
  assertions = [
      check_ops.assert_rank_at_least(alpha_t, 1),
      check_ops.assert_positive(alpha_t),
  ]
  return control_flow_ops.with_dependencies(assertions, alpha_t)
|
||||
|
||||
|
||||
def _check_n(n):
  """Convert n to a tensor guarded by non-negativity and integer-form checks."""
  n_t = ops.convert_to_tensor(n, name='n_before_deps')
  assertions = [
      check_ops.assert_non_negative(n_t),
      _assert_integer_form(n_t),
  ]
  return control_flow_ops.with_dependencies(assertions, n_t)
|
||||
|
||||
|
||||
def _log_combinations(n, counts, name='log_combinations'):
|
||||
"""Log number of ways counts could have come in."""
|
||||
# First a bit about the number of ways counts could have come in:
|
||||
@ -148,6 +133,7 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
|
||||
n,
|
||||
alpha,
|
||||
allow_arbitrary_counts=False,
|
||||
allow_nan=False,
|
||||
strict=True,
|
||||
name='DirichletMultinomial'):
|
||||
"""Initialize a batch of DirichletMultinomial distributions.
|
||||
@ -163,8 +149,15 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
|
||||
allow_arbitrary_counts: Boolean. This represents whether the pmf/cdf
|
||||
allows for the `counts` tensor to be non-integral values.
|
||||
The pmf/cdf are functions that can be evaluated at non-integral values,
|
||||
but are only a distribution over non-negative integers.
|
||||
strict: Not used (yet).
|
||||
but are only a distribution over non-negative integers. If `strict` is
|
||||
`False`, this assertion is turned off.
|
||||
allow_nan: Boolean, default False. If False, raise an exception if
|
||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
||||
If True, batch members with valid parameters leading to undefined
|
||||
statistics will return NaN for this statistic.
|
||||
strict: Whether to assert valid values for parameters `alpha` and `n`, and
|
||||
`x` in `pmf` and `log_pmf`. If False, correct behavior is not
|
||||
guaranteed.
|
||||
name: The name to prefix Ops created by this distribution class.
|
||||
|
||||
Examples:
|
||||
@ -178,9 +171,9 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
|
||||
dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
```
|
||||
"""
|
||||
# TODO(langmore): Should strict supersede allow_arbitrary_counts?
|
||||
# Or work orthogonal to it? Implement correct usage of strict=True
|
||||
self._allow_nan = allow_nan
|
||||
self._strict = strict
|
||||
self._name = name
|
||||
self._allow_arbitrary_counts = allow_arbitrary_counts
|
||||
with ops.op_scope([n, alpha], name):
|
||||
# Broadcasting works because:
|
||||
@ -192,12 +185,9 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
|
||||
# explicitness.
|
||||
# * All calls involving `counts` eventually require a broadcast between
|
||||
# `counts` and alpha.
|
||||
self._alpha = _check_alpha(alpha)
|
||||
self._name = name
|
||||
|
||||
n = _check_n(n)
|
||||
n = math_ops.cast(n, self._alpha.dtype)
|
||||
self._n = n
|
||||
self._alpha = self._check_alpha(alpha)
|
||||
n = self._check_n(n)
|
||||
self._n = math_ops.cast(n, self._alpha.dtype)
|
||||
|
||||
self._alpha_sum = math_ops.reduce_sum(
|
||||
self._alpha, reduction_indices=[-1], keep_dims=False)
|
||||
@ -217,6 +207,11 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
|
||||
"""Parameter defining this distribution."""
|
||||
return self._alpha
|
||||
|
||||
@property
def allow_nan(self):
  """The `allow_nan` flag stored on this distribution (`self._allow_nan`)."""
  return self._allow_nan
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
"""Boolean describing behavior on invalid input."""
|
||||
@ -368,7 +363,9 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
|
||||
|
||||
def _check_counts(self, counts):
|
||||
"""Check counts for proper shape, values, then return tensor version."""
|
||||
counts = ops.convert_to_tensor(counts, name='counts_before_deps')
|
||||
counts = ops.convert_to_tensor(counts, name='counts')
|
||||
if not self.strict:
|
||||
return counts
|
||||
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
|
||||
dependencies = [check_ops.assert_non_negative(counts),
|
||||
check_ops.assert_equal(self._n,
|
||||
@ -378,3 +375,18 @@ class DirichletMultinomial(distribution.DiscreteDistribution):
|
||||
dependencies += [_assert_integer_form(counts)]
|
||||
|
||||
return control_flow_ops.with_dependencies(dependencies, counts)
|
||||
|
||||
def _check_alpha(self, alpha):
  """Convert alpha to a tensor; in strict mode also attach validity asserts."""
  alpha_t = ops.convert_to_tensor(alpha, name='alpha')
  if self.strict:
    # Strict mode: alpha must have rank >= 1 and be strictly positive.
    assertions = [
        check_ops.assert_rank_at_least(alpha_t, 1),
        check_ops.assert_positive(alpha_t),
    ]
    return control_flow_ops.with_dependencies(assertions, alpha_t)
  return alpha_t
|
||||
|
||||
def _check_n(self, n):
  """Convert n to a tensor; in strict mode also attach validity asserts."""
  n_t = ops.convert_to_tensor(n, name='n')
  if self.strict:
    # Strict mode: n must be non-negative and integer-valued.
    assertions = [
        check_ops.assert_non_negative(n_t),
        _assert_integer_form(n_t),
    ]
    return control_flow_ops.with_dependencies(assertions, n_t)
  return n_t
|
||||
|
@ -51,11 +51,11 @@ def stratified_sample(data, labels, probs, batch_size,
|
||||
threads_per_queue: Number of threads for each per-class queue.
|
||||
name: Optional prefix for ops created by this function.
|
||||
Raises:
|
||||
AssertionError: enqueue_many is True and labels doesn't have a batch
|
||||
ValueError: enqueue_many is True and labels doesn't have a batch
|
||||
dimension, or if enqueue_many is False and labels isn't a scalar.
|
||||
AssertionError: enqueue_many is True, and batch dimension on data and labels
|
||||
ValueError: enqueue_many is True, and batch dimension of data and labels
|
||||
don't match.
|
||||
AssertionError: if probs don't sum to one.
|
||||
ValueError: if probs don't sum to one.
|
||||
TFAssertion: if labels aren't integers in [0, num classes).
|
||||
Returns:
|
||||
(data_batch, label_batch)
|
||||
@ -81,11 +81,11 @@ def stratified_sample(data, labels, probs, batch_size,
|
||||
labels = array_ops.expand_dims(labels, 0)
|
||||
|
||||
# Validate that input is consistent.
|
||||
data, labels = _verify_input(data, labels, probs)
|
||||
data, labels, probs = _verify_input(data, labels, probs)
|
||||
|
||||
# Make per-class queues.
|
||||
per_class_queues = _make_per_class_queues(
|
||||
data, labels, len(probs), queue_capacity, threads_per_queue)
|
||||
data, labels, probs.size, queue_capacity, threads_per_queue)
|
||||
|
||||
# Use the per-class queues to generate stratified batches.
|
||||
return _get_batch(per_class_queues, probs, batch_size)
|
||||
@ -93,15 +93,19 @@ def stratified_sample(data, labels, probs, batch_size,
|
||||
|
||||
def _verify_input(data, labels, probs):
|
||||
"""Verify that batched inputs are well-formed."""
|
||||
# Probabilities must be a numpy array or a Python list.
|
||||
if not (isinstance(probs, np.ndarray) or isinstance(probs, list)):
|
||||
raise ValueError('Probabilities must be python or numpy array')
|
||||
# Probabilities must be able to be converted to a 1D non-object numpy array.
|
||||
probs = np.asarray(probs)
|
||||
if probs.dtype == np.dtype('object'):
|
||||
raise ValueError('Probabilities must be able to be converted to a numpy '
|
||||
'array.')
|
||||
if len(probs.shape) != 1:
|
||||
raise ValueError('Probabilities must be 1D.')
|
||||
|
||||
# Probabilities must sum to one.
|
||||
# TODO(joelshor): Investigate whether logits should be passed instead of
|
||||
# probs.
|
||||
if np.sum(probs) != 1.0:
|
||||
raise ValueError('Probabilities must sum to one.')
|
||||
if not np.isclose(np.sum(probs), 1.0):
|
||||
raise ValueError('Probabilities must sum to one.', np.sum(probs))
|
||||
|
||||
# Labels tensor should only have batch dimension.
|
||||
labels.get_shape().assert_has_rank(1)
|
||||
@ -129,7 +133,7 @@ def _verify_input(data, labels, probs):
|
||||
check_ops.assert_less(labels, math_ops.cast(len(probs), labels.dtype))],
|
||||
labels)
|
||||
|
||||
return data, labels
|
||||
return data, labels, probs
|
||||
|
||||
|
||||
def _make_per_class_queues(data, labels, num_classes, queue_capacity,
|
||||
@ -161,12 +165,12 @@ def _make_per_class_queues(data, labels, num_classes, queue_capacity,
|
||||
|
||||
def _get_batch(per_class_queues, probs, batch_size):
|
||||
"""Generates batches according to per-class-probabilities."""
|
||||
num_classes = len(probs)
|
||||
num_classes = probs.size
|
||||
# Number of examples per class is governed by a multinomial distribution.
|
||||
# Note: multinomial takes unnormalized log probabilities for its first
|
||||
# argument.
|
||||
# argument, of dimension [batch_size, num_classes].
|
||||
examples = random_ops.multinomial(
|
||||
math_ops.log([[float(x) for x in probs]]), batch_size)
|
||||
np.expand_dims(np.log(probs), 0), batch_size)
|
||||
|
||||
# Prepare the data and label batches.
|
||||
val_list = []
|
||||
|
@ -63,6 +63,11 @@ class SamplingOpsTest(tf.test.TestCase):
|
||||
tf.contrib.framework.sampling_ops.stratified_sample(
|
||||
val, label, np.array([.1] * 5), batch_size)
|
||||
|
||||
# Probabilities must be 1D.
|
||||
with self.assertRaises(ValueError):
|
||||
tf.contrib.framework.sampling_ops.stratified_sample(
|
||||
val, label, np.array([[.25, .25], [.25, .25]]), batch_size)
|
||||
|
||||
def testRuntimeAssertionFailures(self):
|
||||
probs = [.2] * 5
|
||||
vals = tf.zeros([3, 1])
|
||||
@ -75,14 +80,14 @@ class SamplingOpsTest(tf.test.TestCase):
|
||||
|
||||
# Set up graph with illegal label vector.
|
||||
label_ph = tf.placeholder(tf.int32, shape=[None])
|
||||
batch_tf = tf.contrib.framework.sampling_ops._verify_input(
|
||||
vals_tf, lbls_tf, _ = tf.contrib.framework.sampling_ops._verify_input(
|
||||
vals, label_ph, probs)
|
||||
|
||||
for illegal_label in illegal_labels:
|
||||
# Run session that should fail.
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
sess.run(batch_tf, feed_dict={label_ph: illegal_label})
|
||||
sess.run([vals_tf, lbls_tf], feed_dict={label_ph: illegal_label})
|
||||
|
||||
def testBatchingBehavior(self):
|
||||
batch_size = 20
|
||||
@ -119,13 +124,14 @@ class SamplingOpsTest(tf.test.TestCase):
|
||||
# Set up graph with placeholders.
|
||||
vals_ph = tf.placeholder(tf.float32) # completely undefined shape
|
||||
labels_ph = tf.placeholder(tf.int32) # completely undefined shape
|
||||
batch_tf = tf.contrib.framework.sampling_ops._verify_input(
|
||||
vals_tf, lbls_tf, _ = tf.contrib.framework.sampling_ops._verify_input(
|
||||
vals_ph, labels_ph, probs)
|
||||
|
||||
# Run graph to make sure there are no shape-related runtime errors.
|
||||
for vals, labels in legal_input_pairs:
|
||||
with self.test_session() as sess:
|
||||
sess.run(batch_tf, feed_dict={vals_ph: vals, labels_ph: labels})
|
||||
sess.run([vals_tf, lbls_tf], feed_dict={vals_ph: vals,
|
||||
labels_ph: labels})
|
||||
|
||||
def testNormalBehavior(self):
|
||||
# Set up graph.
|
||||
|
@ -18,6 +18,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/learn/python/learn/datasets",
|
||||
"//tensorflow/contrib/session_bundle:exporter",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
@ -220,6 +221,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "compare_test",
|
||||
size = "small",
|
||||
srcs = ["python/learn/tests/dataframe/compare_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "early_stopping_test",
|
||||
size = "medium",
|
||||
@ -562,6 +575,19 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "export_test",
|
||||
size = "small",
|
||||
srcs = ["python/learn/utils/export_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "stability_test",
|
||||
size = "small",
|
||||
|
@ -32,5 +32,17 @@ from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.sum import Sum
|
||||
|
||||
# pylint: disable=g-import-not-at-top,g-bad-import-order
|
||||
|
||||
# Unary Transform registration
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms import unary_transforms as _ut
|
||||
for ut_def in _ut.UNARY_TRANSFORMS:
|
||||
_ut.register_unary_op(*ut_def)
|
||||
|
||||
# Comparison Transform registration
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms import compare as _cmp
|
||||
for ct_def in _cmp.COMPARISON_TRANSFORMS:
|
||||
_cmp.register_comparison_ops(*ct_def)
|
||||
|
||||
__all__ = ['DataFrame', 'Series', 'TransformedSeries', 'TensorFlowDataFrame',
|
||||
'parameter', 'Transform']
|
||||
|
@ -0,0 +1,156 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Transforms for comparing pairs of `Series`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import series
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
# Each entry is a mapping from registered_name to operation. Each operation is
|
||||
# wrapped in a transform and then registered as a member function
|
||||
# `Series`.registered_name().
|
||||
COMPARISON_TRANSFORMS = [("__eq__", math_ops.equal),
|
||||
("__gt__", math_ops.greater),
|
||||
("__ge__", math_ops.greater_equal),
|
||||
("__lt__", math_ops.less),
|
||||
("__le__", math_ops.less_equal)]
|
||||
|
||||
SERIES_DOC_FORMAT_STRING = (
|
||||
"A `Transform` that uses `{0}` to compare two Series. "
|
||||
"Documentation for `{0}`: \n\n {1}"
|
||||
)
|
||||
|
||||
SCALAR_DOC_FORMAT_STRING = (
|
||||
"A `Transform` that uses `{0}` to compare a Series and a scalar. "
|
||||
"Documentation for `{0}`: \n\n {1}"
|
||||
)
|
||||
|
||||
|
||||
class SeriesComparisonTransform(transform.Transform):
  """Base `Transform` for elementwise comparison of two `Series`.

  The comparison itself is delegated to `self._compare(x, y)`, which this
  class does not define.
  """

  @property
  def input_valency(self):
    # Two inputs: the left- and right-hand side of the comparison.
    return 2

  @property
  def _output_names(self):
    return ("output",)

  def _apply_transform(self, input_tensors):
    lhs = input_tensors[0]
    rhs = input_tensors[1]

    # TODO(jamieas): consider supporting sparse comparisons.
    if any(isinstance(tensor, ops.SparseTensor) for tensor in (lhs, rhs)):
      transform_name = type(self).__name__
      raise TypeError(
          "{} does not support SparseTensors".format(transform_name))

    compared = self._compare(lhs, rhs)
    # pylint: disable=not-callable
    return self.return_type(compared)
|
||||
|
||||
|
||||
class ScalarComparisonTransform(transform.Transform):
  """Base `Transform` for comparing a single `Series` against a scalar.

  The comparison itself is delegated to `self._compare(x)`, which this class
  does not define.
  """

  def __init__(self, threshold):
    # Series-vs-Series comparisons are not handled here; reject them early.
    if isinstance(threshold, series.Series):
      transform_name = type(self).__name__
      raise ValueError(
          "{} is used to compare Series with scalars. It was called with "
          "another Series.".format(transform_name))
    super(ScalarComparisonTransform, self).__init__()
    self._threshold = threshold

  @transform.parameter
  def threshold(self):
    return self._threshold

  @property
  def input_valency(self):
    # One input: the `Series` being compared against the threshold.
    return 1

  @property
  def _output_names(self):
    return ("output",)

  def _apply_transform(self, input_tensors):
    tensor = input_tensors[0]
    if not isinstance(tensor, ops.SparseTensor):
      compared = self._compare(tensor)
    else:
      # For sparse input only the stored values are compared; indices and
      # shape are carried over unchanged.
      compared = ops.SparseTensor(tensor.indices,
                                  self._compare(tensor.values),
                                  tensor.shape)

    # pylint: disable=not-callable
    return self.return_type(compared)
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def register_comparison_ops(method_name, operation):
  """Registers `Series` member functions for comparisons.

  Two `Transform` classes are generated from `operation`: one that compares
  two `Series` (a `SeriesComparisonTransform` subclass) and one that compares
  a `Series` with a scalar (a `ScalarComparisonTransform` subclass).  A
  function that dispatches between them based on the type of its argument is
  then attached to `Series` under `method_name`.

  Args:
    method_name: the name of the method that will be created in `Series`.
    operation: TensorFlow operation used for comparison.
  """

  # Define series-series comparison `Transform`.
  @property
  def series_name(self):
    return operation.__name__

  series_doc = SERIES_DOC_FORMAT_STRING.format(operation.__name__,
                                               operation.__doc__)

  def series_compare(self, x, y):
    return operation(x, y)

  # The generated class is named "series_<op>" so that it can be told apart
  # from the scalar variant below (e.g. in the error message that
  # `SeriesComparisonTransform` builds from `type(self).__name__`).
  series_transform_cls = type("series_{}".format(operation.__name__),
                              (SeriesComparisonTransform,),
                              {"name": series_name,
                               "__doc__": series_doc,
                               "_compare": series_compare})

  # Define series-scalar comparison `Transform`.
  @property
  def scalar_name(self):
    return "scalar_{}".format(operation.__name__)

  scalar_doc = SCALAR_DOC_FORMAT_STRING.format(operation.__name__,
                                               operation.__doc__)

  def scalar_compare(self, x):
    return operation(x, self.threshold)

  scalar_transform_cls = type("scalar_{}".format(operation.__name__),
                              (ScalarComparisonTransform,),
                              {"name": scalar_name,
                               "__doc__": scalar_doc,
                               "_compare": scalar_compare})

  # Define function that delegates to the two `Transforms`.
  def _comparison_fn(self, other, *args, **kwargs):
    # pylint: disable=not-callable,abstract-class-instantiated
    if isinstance(other, series.Series):
      # Series vs. Series: both operands are inputs of the transform.
      return series_transform_cls(*args, **kwargs)([self, other])[0]
    # Series vs. anything else: `other` becomes the transform's threshold.
    return scalar_transform_cls(other, *args, **kwargs)([self])[0]

  # Register new member function of `Series`.
  setattr(series.Series, method_name, _comparison_fn)
|
@ -52,9 +52,7 @@ DOC_FORMAT_STRING = (
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def _register_unary_op(registered_name,
|
||||
operation):
|
||||
|
||||
def register_unary_op(registered_name, operation):
|
||||
"""Creates a `Transform` that wraps a unary tensorflow operation.
|
||||
|
||||
If `registered_name` is specified, the `Transform` is registered as a member
|
||||
@ -100,7 +98,3 @@ def _register_unary_op(registered_name,
|
||||
"_apply_transform": _apply_transform})
|
||||
|
||||
series.Series.register_unary_op(registered_name)(cls)
|
||||
|
||||
|
||||
for ut in UNARY_TRANSFORMS:
|
||||
_register_unary_op(*ut)
|
||||
|
@ -34,7 +34,6 @@ from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -47,6 +46,238 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import training
|
||||
|
||||
|
||||
class _ComposableModel(object):
  """Base class for building blocks that estimators are composed from.

  Subclasses are expected to override the following methods, both of which
  raise `NotImplementedError` here:
    - build_model
    - _get_optimizer
  See below for the required signatures.
  _ComposableModel and its subclasses are not part of the public tf.learn API.
  """

  def __init__(self,
               num_label_columns,
               optimizer,
               weight_collection_name,
               gradient_clip_norm):
    """Common initialization for all _ComposableModel objects.

    Args:
      num_label_columns: The number of label/target columns.
      optimizer: The optimizer used to apply gradients to the model.  It is
        stored as given; resolving it (including the `None` case) is left
        to the subclass's `_get_optimizer`.
      weight_collection_name: A string defining the name to use for the
        collection of weights (e.g. 'dnn').
      gradient_clip_norm: A float > 0. If provided, gradients are clipped
        to their global norm with this clipping ratio. See
        tf.clip_by_global_norm for more details.
    """
    self._num_label_columns = num_label_columns
    self._optimizer = optimizer
    self._weight_collection_name = weight_collection_name
    self._gradient_clip_norm = gradient_clip_norm
    # Stays None until a subclass's build_model records its columns.
    self._feature_columns = None

  def build_model(self, features, feature_columns, is_training):
    """Builds the model that can calculate the logits.

    Args:
      features: A mapping from feature columns to tensors.
      feature_columns: An iterable containing all the feature columns used
        by the model. All items in the set should be instances of
        classes derived from `FeatureColumn`.
      is_training: Set to True when training, False otherwise.

    Returns:
      The logits for this model.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError

  def get_train_step(self, loss):
    """Returns the ops to run to perform a training step on this estimator.

    Args:
      loss: The loss to use when calculating gradients.

    Returns:
      A list holding the op that applies the gradients, or an empty list
      when the model has neither feature columns nor variables.
    """
    trainable_vars = self._get_vars()
    if not self._get_feature_columns() and not trainable_vars:
      return []

    grads = gradients.gradients(loss, trainable_vars)
    clip_norm = self._gradient_clip_norm
    if clip_norm:
      grads, _ = clip_ops.clip_by_global_norm(grads, clip_norm)
    self._optimizer = self._get_optimizer()
    grads_and_vars = zip(grads, trainable_vars)
    return [self._optimizer.apply_gradients(grads_and_vars)]

  def _get_feature_columns(self):
    """Returns the checked, de-duplicated columns sorted by key, or None."""
    columns = self._feature_columns
    if not columns:
      return None
    feature_column_ops.check_feature_columns(columns)
    return sorted(set(columns), key=lambda column: column.key)

  def _get_feature_dict(self, features):
    """Returns `features` as a dict, wrapping a non-dict under the key ""."""
    if not isinstance(features, dict):
      return {"": features}
    return features

  def _get_vars(self):
    """Returns this model's weight collection, or [] without feature columns."""
    if not self._get_feature_columns():
      return []
    return ops.get_collection(self._weight_collection_name)

  def _get_optimizer(self):
    """Returns the optimizer to train with; must be overridden."""
    raise NotImplementedError
|
||||
|
||||
|
||||
class _LinearComposableModel(_ComposableModel):
  """A _ComposableModel whose logits are a weighted sum of feature columns.

  Instances of this class can be used to build estimators through the use
  of composition.
  """

  def __init__(self,
               num_label_columns,
               optimizer=None,
               gradient_clip_norm=None):
    """Initializes _LinearComposableModel objects.

    Args:
      num_label_columns: The number of label/target columns.
      optimizer: The optimizer used to apply gradients to the model, or the
        name of one (a key of `layers.OPTIMIZER_CLS_NAMES`).  If `None`,
        "Ftrl" is used.
      gradient_clip_norm: A float > 0. If provided, gradients are clipped
        to their global norm with this clipping ratio. See
        tf.clip_by_global_norm for more details.
    """
    super(_LinearComposableModel, self).__init__(
        num_label_columns=num_label_columns,
        optimizer=optimizer,
        weight_collection_name="linear",
        gradient_clip_norm=gradient_clip_norm)

  def build_model(self, features, feature_columns, is_training):
    """See base class."""
    columns_to_tensors = self._get_feature_dict(features)
    # Record the columns so that _get_feature_columns() can return them.
    self._feature_columns = feature_columns

    outputs = layers.weighted_sum_from_feature_columns(
        columns_to_tensors=columns_to_tensors,
        feature_columns=self._get_feature_columns(),
        num_outputs=self._num_label_columns,
        weight_collections=[self._weight_collection_name],
        name="linear")
    # Only the first of the three returned values (the logits) is needed.
    logits, _, _ = outputs
    return logits

  def _get_optimizer(self):
    """Resolves `self._optimizer` to an optimizer instance and returns it."""
    if self._optimizer is None:
      self._optimizer = "Ftrl"
    optimizer_name = self._optimizer
    if isinstance(optimizer_name, six.string_types):
      # An optimizer given by name is built with a default learning rate of
      # 1 / sqrt(number of feature columns).
      num_columns = len(self._get_feature_columns())
      default_learning_rate = 1. / math.sqrt(num_columns)
      optimizer_cls = layers.OPTIMIZER_CLS_NAMES[optimizer_name]
      self._optimizer = optimizer_cls(learning_rate=default_learning_rate)
    return self._optimizer
|
||||
|
||||
|
||||
class _DNNComposableModel(_ComposableModel):
  """A _ComposableModel that implements a DNN.

  Instances of this class can be used to build estimators through the use
  of composition.
  """

  def __init__(self,
               num_label_columns,
               hidden_units,
               optimizer=None,
               activation_fn=nn.relu,
               dropout=None,
               gradient_clip_norm=None,
               config=None):
    """Initializes _DNNComposableModel objects.

    Args:
      num_label_columns: The number of label/target columns.
      hidden_units: List of hidden units per layer. All layers are fully
        connected.
      optimizer: An instance of `tf.Optimizer` used to apply gradients to
        the model, or the name of an entry in `layers.OPTIMIZER_CLS_NAMES`.
        If `None`, will use an Adagrad optimizer with learning rate 0.05.
      activation_fn: Activation function applied to each hidden layer.
        Defaults to `tf.nn.relu`. The value is forwarded unchanged to
        `layers.fully_connected`.
        NOTE(review): an explicit `None` is therefore NOT replaced by relu;
        confirm whether callers rely on that.
      dropout: When not None, the probability we will drop out
        a given coordinate.
      gradient_clip_norm: A float > 0. If provided, gradients are clipped
        to their global norm with this clipping ratio. See
        tf.clip_by_global_norm for more details.
      config: RunConfig object to configure the runtime settings. When
        `None`, the layers are built without a variable partitioner.
    """
    super(_DNNComposableModel, self).__init__(
        num_label_columns=num_label_columns,
        optimizer=optimizer,
        weight_collection_name="DNN",
        gradient_clip_norm=gradient_clip_norm)
    self._hidden_units = hidden_units
    self._activation_fn = activation_fn
    self._dropout = dropout
    self._config = config

  def _get_partitioner(self):
    """Returns the partitioner for layer variables, or None without a config."""
    if self._config is None:
      # `config` defaults to None; dereferencing it would raise
      # AttributeError, so fall back to unpartitioned variables.
      return None
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._config.num_ps_replicas)

  def _add_hidden_layer_summary(self, value, tag):
    """Adds zero-fraction and activation-histogram summaries for `value`."""
    # TODO(zakaria): Move this code to tf.learn and add test.
    logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
                               nn.zero_fraction(value))
    logging_ops.histogram_summary("%s:activation" % tag, value)

  def build_model(self, features, feature_columns, is_training):
    """See base class."""
    features = self._get_feature_dict(features)
    self._feature_columns = feature_columns

    net = layers.input_from_feature_columns(
        features,
        self._get_feature_columns(),
        weight_collections=[self._weight_collection_name])
    for layer_id, num_hidden_units in enumerate(self._hidden_units):
      with variable_scope.variable_op_scope(
          [net], "hiddenlayer_%d" % layer_id,
          partitioner=self._get_partitioner()) as scope:
        net = layers.fully_connected(
            net,
            num_hidden_units,
            activation_fn=self._activation_fn,
            variables_collections=[self._weight_collection_name],
            scope=scope)
        # Dropout is only applied while training.
        if self._dropout is not None and is_training:
          net = layers.dropout(
              net,
              keep_prob=(1.0 - self._dropout))
      self._add_hidden_layer_summary(net, scope.name)
    with variable_scope.variable_op_scope(
        [net], "dnn_logits",
        partitioner=self._get_partitioner()) as scope:
      # Final projection to the label columns has no activation.
      logits = layers.fully_connected(
          net,
          self._num_label_columns,
          activation_fn=None,
          variables_collections=[self._weight_collection_name],
          scope=scope)
    self._add_hidden_layer_summary(logits, "dnn_logits")
    return logits

  def _get_optimizer(self):
    """Resolves `self._optimizer` to an optimizer instance and returns it."""
    if self._optimizer is None:
      self._optimizer = "Adagrad"
    if isinstance(self._optimizer, six.string_types):
      self._optimizer = layers.OPTIMIZER_CLS_NAMES[self._optimizer](
          learning_rate=0.05)
    return self._optimizer
|
||||
|
||||
|
||||
# TODO(ispir): Increase test coverage
|
||||
class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
"""An estimator for TensorFlow Linear and DNN joined training models.
|
||||
@ -110,6 +341,21 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
"""
|
||||
super(_DNNLinearCombinedBaseEstimator, self).__init__(model_dir=model_dir,
|
||||
config=config)
|
||||
|
||||
self._linear_model = _LinearComposableModel(
|
||||
num_label_columns=target_column.num_label_columns,
|
||||
optimizer=linear_optimizer,
|
||||
gradient_clip_norm=gradient_clip_norm)
|
||||
|
||||
self._dnn_model = _DNNComposableModel(
|
||||
num_label_columns=target_column.num_label_columns,
|
||||
hidden_units=dnn_hidden_units,
|
||||
optimizer=dnn_optimizer,
|
||||
activation_fn=dnn_activation_fn,
|
||||
dropout=dnn_dropout,
|
||||
gradient_clip_norm=gradient_clip_norm,
|
||||
config=self._config) if dnn_hidden_units else None
|
||||
|
||||
self._linear_feature_columns = linear_feature_columns
|
||||
self._linear_optimizer = linear_optimizer
|
||||
self._dnn_feature_columns = dnn_feature_columns
|
||||
@ -175,22 +421,11 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
loss = self._loss(logits, targets, features)
|
||||
logging_ops.scalar_summary("loss", loss)
|
||||
|
||||
linear_vars = self._get_linear_vars()
|
||||
dnn_vars = self._get_dnn_vars()
|
||||
grads = gradients.gradients(loss, dnn_vars + linear_vars)
|
||||
if self._gradient_clip_norm:
|
||||
grads, _ = clip_ops.clip_by_global_norm(grads,
|
||||
self._gradient_clip_norm)
|
||||
linear_train_step = self._linear_model.get_train_step(loss)
|
||||
dnn_train_step = (self._dnn_model.get_train_step(loss)
|
||||
if self._dnn_model else [])
|
||||
|
||||
dnn_grads = grads[0:len(dnn_vars)]
|
||||
linear_grads = grads[len(dnn_vars):]
|
||||
|
||||
train_ops = self._get_linear_training_ops(
|
||||
linear_grads, linear_vars) + self._get_dnn_training_ops(dnn_grads,
|
||||
dnn_vars)
|
||||
|
||||
train_step = control_flow_ops.group(*train_ops, name="combined_training_op")
|
||||
with ops.control_dependencies([train_step]):
|
||||
with ops.control_dependencies(linear_train_step + dnn_train_step):
|
||||
with ops.get_default_graph().colocate_with(global_step):
|
||||
return state_ops.assign_add(global_step, 1).op, loss
|
||||
|
||||
@ -232,54 +467,13 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
feature_column_ops.check_feature_columns(self._dnn_feature_columns)
|
||||
return sorted(set(self._dnn_feature_columns), key=lambda x: x.key)
|
||||
|
||||
def _dnn_logits(self, features, is_training=False):
|
||||
net = layers.input_from_feature_columns(
|
||||
features,
|
||||
self._get_dnn_feature_columns(),
|
||||
weight_collections=[self._dnn_weight_collection])
|
||||
for layer_id, num_hidden_units in enumerate(self._dnn_hidden_units):
|
||||
with variable_scope.variable_op_scope(
|
||||
[net], "hiddenlayer_%d" % layer_id,
|
||||
partitioner=partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=self._config.num_ps_replicas)) as scope:
|
||||
net = layers.fully_connected(
|
||||
net,
|
||||
num_hidden_units,
|
||||
activation_fn=self._dnn_activation_fn,
|
||||
variables_collections=[self._dnn_weight_collection],
|
||||
scope=scope)
|
||||
if self._dnn_dropout is not None and is_training:
|
||||
net = layers.dropout(
|
||||
net,
|
||||
keep_prob=(1.0 - self._dnn_dropout))
|
||||
self._add_hidden_layer_summary(net, scope.name)
|
||||
with variable_scope.variable_op_scope(
|
||||
[net], "dnn_logit",
|
||||
partitioner=partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=self._config.num_ps_replicas)) as scope:
|
||||
logit = layers.fully_connected(
|
||||
net,
|
||||
self._target_column.num_label_columns,
|
||||
activation_fn=None,
|
||||
variables_collections=[self._dnn_weight_collection],
|
||||
scope=scope)
|
||||
self._add_hidden_layer_summary(logit, "dnn_logit")
|
||||
return logit
|
||||
def _dnn_logits(self, features, is_training):
|
||||
return self._dnn_model.build_model(
|
||||
features, self._dnn_feature_columns, is_training)
|
||||
|
||||
def _add_hidden_layer_summary(self, value, tag):
|
||||
# TODO(zakaria): Move this code to tf.learn and add test.
|
||||
logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
|
||||
nn.zero_fraction(value))
|
||||
logging_ops.histogram_summary("%s:activation" % tag, value)
|
||||
|
||||
def _linear_logits(self, features):
|
||||
logits, _, _ = layers.weighted_sum_from_feature_columns(
|
||||
columns_to_tensors=features,
|
||||
feature_columns=self._get_linear_feature_columns(),
|
||||
num_outputs=self._target_column.num_label_columns,
|
||||
weight_collections=[self._linear_weight_collection],
|
||||
name="linear")
|
||||
return logits
|
||||
def _linear_logits(self, features, is_training):
|
||||
return self._linear_model.build_model(
|
||||
features, self._linear_feature_columns, is_training)
|
||||
|
||||
def _get_feature_dict(self, features):
|
||||
if isinstance(features, dict):
|
||||
@ -318,12 +512,12 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
|
||||
features = self._get_feature_dict(features)
|
||||
if linear_feature_columns and dnn_feature_columns:
|
||||
logits = (self._linear_logits(features) +
|
||||
self._dnn_logits(features, is_training=is_training))
|
||||
logits = (self._linear_logits(features, is_training) +
|
||||
self._dnn_logits(features, is_training))
|
||||
elif dnn_feature_columns:
|
||||
logits = self._dnn_logits(features, is_training=is_training)
|
||||
logits = self._dnn_logits(features, is_training)
|
||||
else:
|
||||
logits = self._linear_logits(features)
|
||||
logits = self._linear_logits(features, is_training)
|
||||
|
||||
if self._enable_centered_bias:
|
||||
return nn.bias_add(logits, self._centered_bias())
|
||||
|
@ -250,13 +250,22 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
name=None):
|
||||
"""Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
Args:
|
||||
x: features.
|
||||
y: targets.
|
||||
input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
feed_fn: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
batch_size: minibatch size to use on the input, defaults to first
|
||||
@ -290,11 +299,14 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
if metrics is not None and not isinstance(metrics, dict):
|
||||
raise ValueError('Metrics argument should be None or dict. '
|
||||
'Got %s.' % metrics)
|
||||
return self._evaluate_model(input_fn=input_fn,
|
||||
feed_fn=feed_fn,
|
||||
steps=steps,
|
||||
metrics=metrics,
|
||||
name=name)
|
||||
eval_results, global_step = self._evaluate_model(input_fn=input_fn,
|
||||
feed_fn=feed_fn,
|
||||
steps=steps,
|
||||
metrics=metrics,
|
||||
name=name)
|
||||
if eval_results is not None:
|
||||
eval_results.update({'global_step': global_step})
|
||||
return eval_results
|
||||
|
||||
def predict(self, x=None, input_fn=None, batch_size=None, outputs=None):
|
||||
"""Returns predictions for given features.
|
||||
@ -477,15 +489,16 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
# Add default monitors.
|
||||
if monitors is None:
|
||||
monitors = []
|
||||
monitors += monitors_lib.get_default_monitors(
|
||||
loss_op=loss_op,
|
||||
summary_op=logging_ops.get_summary_op(),
|
||||
save_summary_steps=self._config.save_summary_steps,
|
||||
summary_writer=graph_actions.get_summary_writer(self._model_dir))
|
||||
|
||||
is_chief = self._config.task == 0
|
||||
if not is_chief:
|
||||
# Run monitors only on chief.
|
||||
|
||||
if is_chief:
|
||||
monitors += monitors_lib.get_default_monitors(
|
||||
loss_op=loss_op,
|
||||
summary_op=logging_ops.get_summary_op(),
|
||||
save_summary_steps=self._config.save_summary_steps,
|
||||
summary_writer=graph_actions.get_summary_writer(self._model_dir))
|
||||
else:
|
||||
monitors = []
|
||||
|
||||
# Setup monitors.
|
||||
@ -545,7 +558,7 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
# TODO(wicke): Remove this once Model and associated code are gone.
|
||||
if (hasattr(self._config, 'execution_mode') and
|
||||
self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset')):
|
||||
return
|
||||
return None, None
|
||||
|
||||
# Check that model has been trained.
|
||||
checkpoint_path = self._model_dir
|
||||
@ -564,7 +577,7 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
self._check_inputs(features, targets)
|
||||
eval_dict = self._get_eval_ops(features, targets, metrics)
|
||||
update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
|
||||
eval_results, _ = graph_actions.evaluate(
|
||||
eval_results, current_global_step = graph_actions.evaluate(
|
||||
graph=g,
|
||||
output_dir=eval_dir,
|
||||
checkpoint_path=checkpoint_path,
|
||||
@ -574,7 +587,8 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
supervisor_master=self._config.master,
|
||||
feed_fn=feed_fn,
|
||||
max_steps=steps)
|
||||
return eval_results
|
||||
|
||||
return eval_results, current_global_step
|
||||
|
||||
def _get_features_from_input_fn(self, input_fn):
|
||||
result = input_fn()
|
||||
|
@ -240,6 +240,8 @@ class EstimatorTest(tf.test.TestCase):
|
||||
predictions = est.predict(x=boston.data)
|
||||
other_score = _sklearn.mean_squared_error(predictions, boston.target)
|
||||
self.assertAllClose(other_score, scores['MSE'])
|
||||
self.assertTrue('global_step' in scores)
|
||||
self.assertEqual(scores['global_step'], 100)
|
||||
|
||||
def testIrisAll(self):
|
||||
iris = tf.contrib.learn.datasets.load_iris()
|
||||
@ -257,6 +259,8 @@ class EstimatorTest(tf.test.TestCase):
|
||||
axis=1))
|
||||
other_score = _sklearn.accuracy_score(iris.target, predictions['class'])
|
||||
self.assertAllClose(other_score, scores['accuracy'])
|
||||
self.assertTrue('global_step' in scores)
|
||||
self.assertEqual(scores['global_step'], 100)
|
||||
|
||||
def testIrisInputFn(self):
|
||||
iris = tf.contrib.learn.datasets.load_iris()
|
||||
|
@ -22,6 +22,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.utils import export
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import saver
|
||||
@ -520,3 +521,35 @@ class GraphDump(BaseMonitor):
|
||||
else:
|
||||
matched.append(key)
|
||||
return matched, non_matched
|
||||
|
||||
|
||||
class ExportMonitor(EveryN):
  """Monitor that exports Estimator every N steps."""

  def __init__(self, every_n_steps, export_dir, exports_to_keep=5):
    """Initializes ExportMonitor.

    Args:
      every_n_steps: Run monitor every N steps.
      export_dir: str, folder to export.
      exports_to_keep: int, number of exports to keep.
    """
    super(ExportMonitor, self).__init__(every_n_steps=every_n_steps)
    self.export_dir = export_dir
    self.exports_to_keep = exports_to_keep

  def _export(self):
    """Exports the estimator, skipping the export on `RuntimeError`."""
    try:
      export.export_estimator(self._estimator, self.export_dir,
                              exports_to_keep=self.exports_to_keep)
    except RuntimeError:
      # Currently we are not synchronized with saving checkpoints, which leads
      # to runtime errors when we are calling export on the same global step.
      logging.info("Skipping exporting for the same step. "
                   "Consider exporting less frequently.")

  def every_n_step_end(self, step, outputs):
    """Exports the estimator after every N-th step."""
    super(ExportMonitor, self).every_n_step_end(step, outputs)
    self._export()

  def end(self):
    """Performs a final export when training ends."""
    super(ExportMonitor, self).end()
    # The last step may already have been exported by `every_n_step_end`;
    # guard this export the same way so teardown does not fail.
    self._export()
|
||||
|
@ -0,0 +1,78 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for comparison transforms."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import tensorflow_dataframe as df
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.compare import COMPARISON_TRANSFORMS
|
||||
|
||||
NUMPY_ARRAY_SIZE = 100
|
||||
SCALAR = 50
|
||||
|
||||
|
||||
class CompareTestCase(tf.test.TestCase):
  """Test class for comparison transforms."""

  @classmethod
  def add_test_case(cls, fn_name, op):
    """Attaches a generated test for one comparison transform to this class.

    Args:
      fn_name: Name of the comparison method expected on a dataframe series.
      op: Callable applied to the built tensors to produce the expected
        result; its `__name__` is used to name the generated test method.
    """
    def _test(self):
      rng = np.arange(-NUMPY_ARRAY_SIZE // 2,
                      NUMPY_ARRAY_SIZE // 2,
                      dtype="float32")

      frame = df.TensorFlowDataFrame.from_numpy(rng,
                                                batch_size=len(rng),
                                                shuffle=False)

      frame["sqr"] = frame["value"].square()

      self.assertTrue(hasattr(frame["value"], fn_name))

      # Exercise the transform against both another series and a scalar.
      frame["series_result"] = getattr(frame["value"],
                                       fn_name)(frame["sqr"])
      frame["scalar_result"] = getattr(frame["value"], fn_name)(SCALAR)

      frame_built = frame.build()

      expected_series_tensor = op(frame_built["value"], frame_built["sqr"])
      actual_series_tensor = frame_built["series_result"]

      expected_scalar_tensor = op(frame_built["value"], SCALAR)
      actual_scalar_tensor = frame_built["scalar_result"]

      # Close the session and stop the queue-runner threads even when
      # evaluation fails, so one failing case cannot leak into the next.
      with tf.Session() as session:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=session, coord=coord)
        try:
          actual_series, expected_series, actual_scalar, expected_scalar = (
              session.run([actual_series_tensor, expected_series_tensor,
                           actual_scalar_tensor, expected_scalar_tensor]))
        finally:
          coord.request_stop()
          coord.join(threads)
      np.testing.assert_almost_equal(expected_series, actual_series)
      np.testing.assert_almost_equal(expected_scalar, actual_scalar)
    setattr(cls, "test{}".format(op.__name__), _test)
|
||||
|
||||
for ct in COMPARISON_TRANSFORMS:
|
||||
CompareTestCase.add_test_case(*ct)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -21,6 +21,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import csv
|
||||
import math
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
@ -62,13 +63,14 @@ def _assert_df_equals_dict(expected_df, actual_dict):
|
||||
def _make_test_csv():
|
||||
f = tempfile.NamedTemporaryFile(delete=False, mode="w")
|
||||
w = csv.writer(f)
|
||||
w.writerow(["int", "float", "bool"])
|
||||
w.writerow(["int", "float", "bool", "string"])
|
||||
for _ in range(100):
|
||||
intvalue = np.random.randint(-10, 10)
|
||||
floatvalue = np.random.rand()
|
||||
boolvalue = int(np.random.rand() > 0.3)
|
||||
stringvalue = "S: %.4f" % np.random.rand()
|
||||
|
||||
row = [intvalue, floatvalue, boolvalue]
|
||||
row = [intvalue, floatvalue, boolvalue, stringvalue]
|
||||
w.writerow(row)
|
||||
f.close()
|
||||
return f.name
|
||||
@ -77,14 +79,16 @@ def _make_test_csv():
|
||||
def _make_test_csv_sparse():
|
||||
f = tempfile.NamedTemporaryFile(delete=False, mode="w")
|
||||
w = csv.writer(f)
|
||||
w.writerow(["int", "float", "bool"])
|
||||
w.writerow(["int", "float", "bool", "string"])
|
||||
for _ in range(100):
|
||||
# leave columns empty; these will be read as default value (e.g. 0 or NaN)
|
||||
intvalue = np.random.randint(-10, 10) if np.random.rand() > 0.5 else ""
|
||||
floatvalue = np.random.rand() if np.random.rand() > 0.5 else ""
|
||||
boolvalue = int(np.random.rand() > 0.3) if np.random.rand() > 0.5 else ""
|
||||
stringvalue = (("S: %.4f" % np.random.rand())
|
||||
if np.random.rand() > 0.5 else "")
|
||||
|
||||
row = [intvalue, floatvalue, boolvalue]
|
||||
row = [intvalue, floatvalue, boolvalue, stringvalue]
|
||||
w.writerow(row)
|
||||
f.close()
|
||||
return f.name
|
||||
@ -180,7 +184,7 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
|
||||
enqueue_size = 7
|
||||
|
||||
data_path = _make_test_csv()
|
||||
default_values = [0, 0.0, 0]
|
||||
default_values = [0, 0.0, 0, ""]
|
||||
|
||||
pandas_df = pd.read_csv(data_path)
|
||||
tensorflow_df = df.TensorFlowDataFrame.from_csv(
|
||||
@ -200,7 +204,7 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
|
||||
expected_num_batches = (num_epochs * 100) // batch_size
|
||||
|
||||
data_path = _make_test_csv()
|
||||
default_values = [0, 0.0, 0]
|
||||
default_values = [0, 0.0, 0, ""]
|
||||
|
||||
tensorflow_df = df.TensorFlowDataFrame.from_csv(
|
||||
[data_path],
|
||||
@ -221,10 +225,17 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
|
||||
feature_spec = {
|
||||
"int": tf.FixedLenFeature(None, dtypes.int16, np.nan),
|
||||
"float": tf.VarLenFeature(dtypes.float16),
|
||||
"bool": tf.VarLenFeature(dtypes.bool)
|
||||
"bool": tf.VarLenFeature(dtypes.bool),
|
||||
"string": tf.FixedLenFeature(None, dtypes.string, "")
|
||||
}
|
||||
|
||||
pandas_df = pd.read_csv(data_path)
|
||||
pandas_df = pd.read_csv(data_path, dtype={"string": object})
|
||||
# Pandas insanely uses NaN for empty cells in a string column.
|
||||
# And, we can't use Pandas replace() to fix them because nan != nan
|
||||
s = pandas_df["string"]
|
||||
for i in range(0, len(s)):
|
||||
if isinstance(s[i], float) and math.isnan(s[i]):
|
||||
s[i] = ""
|
||||
tensorflow_df = df.TensorFlowDataFrame.from_csv_with_feature_spec(
|
||||
[data_path],
|
||||
batch_size=batch_size,
|
||||
|
@ -20,3 +20,4 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.utils import checkpoints
|
||||
from tensorflow.contrib.learn.python.learn.utils.export import export_estimator
|
||||
|
134
tensorflow/contrib/learn/python/learn/utils/export.py
Normal file
134
tensorflow/contrib/learn/python/learn/utils/export.py
Normal file
@ -0,0 +1,134 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Export utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
from tensorflow.contrib.session_bundle import exporter
|
||||
from tensorflow.contrib.session_bundle import gc
|
||||
from tensorflow.python.client import session as tf_session
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import saver as tf_saver
|
||||
|
||||
|
||||
def _get_first_op_from_collection(collection_name):
  """Returns the first item of the named graph collection, or None if empty."""
  elements = ops.get_collection(collection_name)
  return elements[0] if elements else None
|
||||
|
||||
|
||||
def _get_saver():
  """Returns the saver for the default graph, creating one if necessary.

  Looks up the first entry of the `SAVERS` collection. If there is none and
  the graph has variables, a new `Saver` is created and registered in that
  collection.

  Returns:
    A saver, or `None` when none is registered and there are no variables.
  """
  # The helper already returns a single element (not a list), so it must not
  # be indexed again.
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is None and variables.all_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
|
||||
|
||||
|
||||
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
  """Restores a checkpoint into a new Session and exports via session_bundle.

  Args:
    graph: Graph to export; made the default graph while exporting.
    saver: Saver used to restore `checkpoint_path` and handed to the exporter.
    checkpoint_path: Path of the checkpoint to restore.
    export_dir: Directory handed to the exporter.
    default_graph_signature: Default signature handed to the exporter.
    named_graph_signatures: Named signatures handed to the exporter.
    exports_to_keep: Forwarded to the exporter's `export` call.
  """
  with graph.as_default(), tf_session.Session('') as session:
    session.run(variables.initialize_local_variables())
    saver.restore(session, checkpoint_path)

    graph_exporter = exporter.Exporter(saver)
    graph_exporter.init(session.graph.as_graph_def(),
                        default_graph_signature=default_graph_signature,
                        named_graph_signatures=named_graph_signatures)
    global_step = contrib_variables.get_global_step()
    graph_exporter.export(export_dir, global_step, session,
                          exports_to_keep=exports_to_keep)
|
||||
|
||||
|
||||
def _generic_signature_fn(examples, unused_features, predictions):
  """Builds a generic signature from the given examples and predictions.

  This is needed for backward compatibility with the default behaviour of
  export_estimator.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `dict` of `Tensor`s; any other value is exported under the
      key 'outputs'.

  Returns:
    Tuple of default signature and named signature (always an empty dict).
  """
  if isinstance(predictions, dict):
    outputs = predictions
  else:
    outputs = {'outputs': predictions}

  signature_tensors = {'inputs': examples}
  signature_tensors.update(outputs)
  return exporter.generic_signature(signature_tensors), {}
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def _default_input_fn(estimator, examples):
|
||||
"""Creates default input parsing using Estimator's feature signatures."""
|
||||
return estimator._get_feature_ops_from_example(examples)
|
||||
|
||||
|
||||
def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn, default_batch_size=1,
                     exports_to_keep=None):
  """Exports inference graph into given dir.

  Args:
    estimator: Estimator to export
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    input_fn: Function that given `Tensor` of `Example` strings, parses it into
      features that are then passed to the model.
    signature_fn: Function that given `Tensor` of `Example` strings,
      `dict` of `Tensor`s for features and `dict` of `Tensor`s for predictions
      and returns default and named exporting signatures.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep.
  """
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as export_graph:
    contrib_variables.create_global_step(export_graph)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)
    default_signature, named_graph_signatures = signature_fn(
        examples, features, predictions)

    # Turn the numeric count into the retention filter the exporter expects.
    keep_filter = exports_to_keep
    if keep_filter is not None:
      keep_filter = gc.largest_export_versions(keep_filter)

    _export_graph(export_graph, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=keep_filter)
|
||||
# pylint: enable=protected-access
|
||||
|
50
tensorflow/contrib/learn/python/learn/utils/export_test.py
Normal file
50
tensorflow/contrib/learn/python/learn/utils/export_test.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for export tools."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib import learn
|
||||
|
||||
|
||||
class ExportTest(tf.test.TestCase):
  """Tests for the export utilities."""

  def testExportMonitor(self):
    """Checks that ExportMonitor keeps only the most recent export."""
    random.seed(42)
    # The data below comes from NumPy's generator, which `random.seed` does
    # not affect, so seed it as well.
    np.random.seed(42)
    x = np.random.rand(1000)
    y = 2 * x + 3
    regressor = learn.LinearRegressor()
    # Keep the export inside the temp dir; without the leading separator the
    # path would be a sibling named "<tmpdir>export/".
    export_dir = tempfile.mkdtemp() + '/export/'
    export_monitor = learn.monitors.ExportMonitor(every_n_steps=1,
                                                  export_dir=export_dir,
                                                  exports_to_keep=1)
    regressor.fit(x, y, steps=10,
                  monitors=[export_monitor])
    self.assertTrue(tf.gfile.Exists(export_dir))
    # With exports_to_keep=1 only the export of the final step remains.
    self.assertFalse(tf.gfile.Exists(export_dir + '00000000/export'))
    self.assertTrue(tf.gfile.Exists(export_dir + '00000010/export'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -50,7 +50,10 @@ py_library(
|
||||
],
|
||||
data = [":python/ops/_sdca_ops.so"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":sdca_ops"],
|
||||
deps = [
|
||||
":sdca_ops",
|
||||
"//tensorflow/contrib/lookup:lookup_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -38,40 +38,15 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "resources",
|
||||
srcs = ["resources.cc"],
|
||||
hdrs = ["resources.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
"@farmhash_archive//:farmhash",
|
||||
"@protobuf//:protobuf",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "resources_test",
|
||||
size = "small",
|
||||
srcs = ["resources_test.cc"],
|
||||
deps = [
|
||||
":resources",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sdca_ops",
|
||||
srcs = ["sdca_ops.cc"],
|
||||
deps = [
|
||||
":loss_updaters",
|
||||
":resources",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core/kernels:bounds_check_lib",
|
||||
"//third_party/eigen3",
|
||||
"@farmhash_archive//:farmhash",
|
||||
"@protobuf//:protobuf",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -1,86 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Records the container name and solver UUID; the backing map starts empty.
DataByExample::DataByExample(const string& container, const string& solver_uuid)
    : container_(container),
      solver_uuid_(solver_uuid) {
}
|
||||
|
||||
// Nothing to release explicitly; the members clean up after themselves.
DataByExample::~DataByExample() = default;
|
||||
|
||||
// static
|
||||
DataByExample::EphemeralKey DataByExample::MakeKey(const string& example_id) {
|
||||
return Fingerprint128(example_id);
|
||||
}
|
||||
|
||||
// Returns a copy of the entry stored under `key`. The map's operator[] first
// inserts a default-constructed entry when `key` is not present yet.
DataByExample::Data DataByExample::Get(const EphemeralKey& key) {
  mutex_lock lock(mu_);
  const Data& entry = data_by_key_[key];
  return entry;
}
|
||||
|
||||
void DataByExample::Set(const EphemeralKey& key, const Data& data) {
|
||||
mutex_lock l(mu_);
|
||||
data_by_key_[key] = data;
|
||||
}
|
||||
|
||||
// Calls `visitor` on every stored entry, at most kVisitChunkSize entries per
// lock acquisition. Returns UNAVAILABLE as soon as the number of entries
// differs from the count taken when the visit started.
Status DataByExample::Visit(
    std::function<void(const Data& data)> visitor) const {
  // Snapshot, under lock, how many entries exist and where iteration starts.
  size_t expected_size = 0;
  DataByKey::const_iterator cursor;
  {
    mutex_lock lock(mu_);
    expected_size = data_by_key_.size();
    cursor = data_by_key_.cbegin();
  }

  size_t visited = 0;
  while (visited < expected_size) {
    // The lock is held for one chunk and released between chunks.
    mutex_lock lock(mu_);
    // Entries are only ever modified or appended, so an unchanged size means
    // no entry was added since the snapshot and the visit may carry on.
    if (data_by_key_.size() != expected_size) {
      return errors::Unavailable("The number of elements for ", solver_uuid_,
                                 " has changed which nullifies a visit.");
    }
    size_t remaining_in_chunk = kVisitChunkSize;
    while (remaining_in_chunk > 0 && visited < expected_size) {
      visitor(cursor->second);
      --remaining_in_chunk;
      ++visited;
      ++cursor;
    }
  }
  return Status::OK();
}
|
||||
|
||||
// Describes this resource by its container name and solver UUID.
string DataByExample::DebugString() {
  const string description =
      strings::StrCat("DataByExample(", container_, ", ", solver_uuid_, ")");
  return description;
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,108 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Resource for storing per-example data across many sessions. The data is
|
||||
// operated on in a modify or append fashion (data can be modified or added, but
|
||||
// never deleted).
|
||||
//
|
||||
// This class is thread-safe.
|
||||
class DataByExample : public ResourceBase {
|
||||
public:
|
||||
// The container and solver_uuid are only used for debugging purposes.
|
||||
DataByExample(const string& container, const string& solver_uuid);
|
||||
|
||||
virtual ~DataByExample();
|
||||
|
||||
// Platform independent, compact and unique (with very high probability)
|
||||
// representation of an example id. 'Ephemeral' because it shouldn't be put
|
||||
// in persistent storage, as its implementation may change in the future.
|
||||
//
|
||||
// The current probability of at least one collision for 1B example_ids is
|
||||
// approximately 10^-21 (ie 2^60 / 2^129).
|
||||
using EphemeralKey = Fprint128;
|
||||
|
||||
// Makes a key for the supplied example_id, for compact storage.
|
||||
static EphemeralKey MakeKey(const string& example_id);
|
||||
|
||||
struct Data {
|
||||
float dual = 0;
|
||||
float primal_loss = 0;
|
||||
float dual_loss = 0;
|
||||
float example_weight = 0;
|
||||
};
|
||||
|
||||
// Accessor and mutator for the entry at Key. Accessor creates an entry with
|
||||
// default value (default constructed object) if the key is not present and
|
||||
// returns it.
|
||||
Data Get(const EphemeralKey& key) LOCKS_EXCLUDED(mu_);
|
||||
void Set(const EphemeralKey& key, const Data& data) LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Visits all elements in this resource. The view of each element (Data) is
|
||||
// atomic, but the entirety of the visit is not (ie the visitor might see
|
||||
// different versions of the Data across elements).
|
||||
//
|
||||
// Returns OK on success or UNAVAILABLE if the number of elements in this
|
||||
// container has changed since the beginning of the visit (in which case the
|
||||
// visit cannot be completed and is aborted early, and computation can be
|
||||
// restarted).
|
||||
Status Visit(std::function<void(const Data& data)> visitor) const
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
|
||||
string DebugString() override;
|
||||
|
||||
private:
|
||||
// Backing container.
|
||||
//
|
||||
// sizeof(EntryPayload) =
|
||||
// sizeof(Key) + sizeof(Data) =
|
||||
// 16 + 16 = 32.
|
||||
//
|
||||
// So on average we use ~51.5 (32 + 19.5) bytes per entry in this table.
|
||||
using EphemeralKeyHasher = Fprint128Hasher;
|
||||
using DataByKey = std::unordered_map<EphemeralKey, Data, EphemeralKeyHasher>;
|
||||
|
||||
// TODO(sibyl-Mooth6ku): Benchmark and/or optimize this.
|
||||
static const size_t kVisitChunkSize = 100;
|
||||
|
||||
const string container_;
|
||||
const string solver_uuid_;
|
||||
|
||||
// TODO(sibyl-Mooth6ku): Come up with a more efficient locking scheme.
|
||||
mutable mutex mu_;
|
||||
DataByKey data_by_key_ GUARDED_BY(mu_);
|
||||
|
||||
friend class DataByExampleTest;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_
|
@ -1,184 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
|
||||
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Operators for testing convenience (for EQ and NE GUnit macros).
|
||||
bool operator==(const DataByExample::Data& lhs,
|
||||
const DataByExample::Data& rhs) {
|
||||
return lhs.dual == rhs.dual && //
|
||||
lhs.primal_loss == rhs.primal_loss && //
|
||||
lhs.dual_loss == rhs.dual_loss && //
|
||||
lhs.example_weight == rhs.example_weight;
|
||||
}
|
||||
|
||||
// Negation of the field-by-field equality operator.
bool operator!=(const DataByExample::Data& lhs,
                const DataByExample::Data& rhs) {
  const bool equal = (lhs == rhs);
  return !equal;
}
|
||||
|
||||
class DataByExampleTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
const string solver_uuid = "TheSolver";
|
||||
ASSERT_TRUE(resource_manager_
|
||||
.LookupOrCreate<DataByExample>(
|
||||
container_, solver_uuid, &data_by_example_,
|
||||
[&, this](DataByExample** ret) {
|
||||
*ret = new DataByExample(container_, solver_uuid);
|
||||
return Status::OK();
|
||||
})
|
||||
.ok());
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
data_by_example_->Unref();
|
||||
ASSERT_TRUE(resource_manager_.Cleanup(container_).ok());
|
||||
}
|
||||
|
||||
// Accessors and mutators to private members of DataByExample for better
|
||||
// testing.
|
||||
static size_t VisitChunkSize() { return DataByExample::kVisitChunkSize; }
|
||||
void InsertReservedEntryUnlocked() NO_THREAD_SAFETY_ANALYSIS {
|
||||
data_by_example_->data_by_key_[{0, 0}];
|
||||
}
|
||||
|
||||
const string container_ = "TheContainer";
|
||||
ResourceMgr resource_manager_;
|
||||
DataByExample* data_by_example_ = nullptr;
|
||||
};
|
||||
|
||||
TEST_F(DataByExampleTest, MakeKeyIsCollisionResistent) {
  // Basic sanity check: the two 64-bit halves of a key are not the same.
  const auto key = DataByExample::MakeKey("TheExampleId");
  EXPECT_NE(key.low64, key.high64);
}
|
||||
|
||||
TEST_F(DataByExampleTest, MakeKeyIsPlatformAgnostic) {
  // Pins MakeKey to exact values. Since this test may run on different
  // platforms, matching the same constants everywhere is one way of enforcing
  // that the key is platform-agnostic.
  // This is *not* a frozen test: the expected values may be updated if the
  // implementation of MakeKey changes.
  const unsigned long long kExpectedLow64 = 10492632643343118393ULL;
  const unsigned long long kExpectedHigh64 = 1007244271654873956ULL;
  const auto key = DataByExample::MakeKey("TheExampleId");
  EXPECT_EQ(kExpectedLow64, key.low64);
  EXPECT_EQ(kExpectedHigh64, key.high64);
}
|
||||
|
||||
TEST_F(DataByExampleTest, ElementAccessAndMutation) {
  // Reading a key that was never written yields default-constructed Data.
  const auto first_key = DataByExample::MakeKey("TheExampleId1");
  const DataByExample::Data default_data;
  EXPECT_EQ(default_data, data_by_example_->Get(first_key));

  // A written value is returned by the next read of the same key.
  DataByExample::Data written;
  written.dual = 1.0f;
  data_by_example_->Set(first_key, written);
  EXPECT_EQ(written, data_by_example_->Get(first_key));

  // A second key is unaffected by the write to the first one.
  const auto second_key = DataByExample::MakeKey("TheExampleId2");
  EXPECT_NE(data_by_example_->Get(first_key),
            data_by_example_->Get(second_key));
}
|
||||
|
||||
TEST_F(DataByExampleTest, VisitEmpty) {
  // With nothing stored, the visit succeeds and never calls the visitor.
  size_t visit_count = 0;
  const Status visit_status = data_by_example_->Visit(
      [&](const DataByExample::Data& data) { ++visit_count; });
  ASSERT_TRUE(visit_status.ok());
  EXPECT_EQ(0, visit_count);
}
|
||||
|
||||
TEST_F(DataByExampleTest, VisitMany) {
  // Store more than two chunks' worth of entries; entry i has dual == i.
  const size_t kNumElements = 2 * VisitChunkSize() + 1;
  for (size_t i = 0; i < kNumElements; ++i) {
    DataByExample::Data entry;
    entry.dual = static_cast<float>(i);
    const auto key = DataByExample::MakeKey(strings::StrCat(i));
    data_by_example_->Set(key, entry);
  }

  size_t visit_count = 0;
  double dual_sum = 0;
  const Status visit_status =
      data_by_example_->Visit([&](const DataByExample::Data& data) {
        ++visit_count;
        dual_sum += data.dual;
      });
  ASSERT_TRUE(visit_status.ok());
  EXPECT_EQ(kNumElements, visit_count);

  // 0 + 1 + ... + (N-1) = (N-1)*N/2
  const double expected_dual_sum = (kNumElements - 1) * kNumElements / 2.0;
  EXPECT_DOUBLE_EQ(expected_dual_sum, dual_sum);
}
|
||||
|
||||
TEST_F(DataByExampleTest, VisitUnavailable) {
  // Create two chunks' worth of entries so that the visit is chunked.
  const size_t kNumEntries = 2 * VisitChunkSize();
  for (size_t i = 0; i < kNumEntries; ++i) {
    data_by_example_->Get(DataByExample::MakeKey(strings::StrCat(i)));
  }

  // A flag that, once raised, stays raised and releases every waiter.
  struct Condition {
    mutex mu;
    bool c GUARDED_BY(mu) = false;
    condition_variable cv;
  };
  auto raise = [](Condition* const condition) {
    mutex_lock lock(condition->mu);
    condition->c = true;
    condition->cv.notify_all();
  };
  auto await = [](Condition* const condition) {
    mutex_lock lock(condition->mu);
    while (!condition->c) {
      condition->cv.wait(lock);
    }
  };

  Condition paused_visit;     // Raised once the Visit has paused.
  Condition updated_data;     // Raised once the data has been updated.
  Condition completed_visit;  // Raised once the Visit has returned.

  thread::ThreadPool thread_pool(Env::Default(), "test", 2 /* num_threads */);
  Status status;
  size_t num_visited = 0;

  // Visiting thread: pauses when the first chunk's worth of elements has been
  // visited and resumes only after the other thread has inserted an entry.
  auto run_visit = [&] {
    status = data_by_example_->Visit([&](const DataByExample::Data& unused) {
      ++num_visited;
      if (num_visited != VisitChunkSize()) {
        return;
      }
      // Safe point to mutate the data structure without a lock below.
      raise(&paused_visit);
      await(&updated_data);
    });
    raise(&completed_visit);
  };
  // Mutating thread: adds one entry while the visit is paused.
  auto run_update = [&, this] {
    await(&paused_visit);
    InsertReservedEntryUnlocked();
    raise(&updated_data);
  };
  thread_pool.Schedule(run_visit);
  thread_pool.Schedule(run_update);

  await(&completed_visit);
  EXPECT_TRUE(errors::IsUnavailable(status));
}
|
||||
|
||||
} // namespace tensorflow
|
@ -28,14 +28,12 @@ limitations under the License.
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h"
|
||||
#include "tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h"
|
||||
#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
|
||||
#include "tensorflow/contrib/linear_optimizer/kernels/squared-loss.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
@ -47,6 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/sparse/group_iterator.h"
|
||||
@ -596,22 +595,18 @@ class FeaturesAndWeights {
|
||||
};
|
||||
|
||||
Status RunTrainStepsForMiniBatch(
|
||||
const int num_examples, const TTypes<const string>::Vec example_ids,
|
||||
const TTypes<const float>::Vec example_labels,
|
||||
const int num_examples, const TTypes<const float>::Vec example_labels,
|
||||
const TTypes<const float>::Vec example_weights,
|
||||
const DeviceBase::CpuWorkerThreads& worker_threads,
|
||||
const Regularizations& regularizations, const DualLossUpdater& loss_updater,
|
||||
FeaturesAndWeights* const features_and_weights,
|
||||
DataByExample* const data_by_example) {
|
||||
TTypes<float>::Matrix example_state_data) {
|
||||
// Process examples in parallel, in a partitioned fashion.
|
||||
mutex mu;
|
||||
Status train_step_status GUARDED_BY(mu);
|
||||
auto train_step = [&](const int64 begin, const int64 end) {
|
||||
for (int64 example_index = begin; example_index < end; ++example_index) {
|
||||
// Get example id, label, and weight.
|
||||
const DataByExample::EphemeralKey example_key =
|
||||
DataByExample::MakeKey(example_ids(example_index));
|
||||
DataByExample::Data data = data_by_example->Get(example_key);
|
||||
const float dual = example_state_data(example_index, 0);
|
||||
const float example_weight = example_weights(example_index);
|
||||
float example_label = example_labels(example_index);
|
||||
const Status conversion_status =
|
||||
@ -633,24 +628,23 @@ Status RunTrainStepsForMiniBatch(
|
||||
const double primal_loss = loss_updater.ComputePrimalLoss(
|
||||
per_example_data.wx, example_label, example_weight);
|
||||
|
||||
const double dual_loss = loss_updater.ComputeDualLoss(
|
||||
data.dual, example_label, example_weight);
|
||||
const double dual_loss =
|
||||
loss_updater.ComputeDualLoss(dual, example_label, example_weight);
|
||||
|
||||
const double new_dual = loss_updater.ComputeUpdatedDual(
|
||||
example_label, example_weight, data.dual, per_example_data.wx,
|
||||
example_label, example_weight, dual, per_example_data.wx,
|
||||
per_example_data.normalized_squared_norm, primal_loss, dual_loss);
|
||||
|
||||
// Compute new weights.
|
||||
const double bounded_dual_delta = (new_dual - data.dual) * example_weight;
|
||||
const double bounded_dual_delta = (new_dual - dual) * example_weight;
|
||||
features_and_weights->UpdateDeltaWeights(
|
||||
example_index, bounded_dual_delta, regularizations.symmetric_l2());
|
||||
|
||||
// Update example data.
|
||||
data.dual = new_dual;
|
||||
data.primal_loss = primal_loss;
|
||||
data.dual_loss = dual_loss;
|
||||
data.example_weight = example_weight;
|
||||
data_by_example->Set(example_key, data);
|
||||
example_state_data(example_index, 0) = new_dual;
|
||||
example_state_data(example_index, 1) = primal_loss;
|
||||
example_state_data(example_index, 2) = dual_loss;
|
||||
example_state_data(example_index, 3) = example_weight;
|
||||
}
|
||||
};
|
||||
// TODO(sibyl-Aix6ihai): Current multiplier 100000 works well empirically
|
||||
@ -689,30 +683,9 @@ class SdcaSolver : public OpKernel {
|
||||
OP_REQUIRES_OK(context, regularizations_.Initialize(context));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
|
||||
&num_inner_iterations_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Get a handle on a shared container across invocations of this Kernel.
|
||||
// The shared container is intended to maintain state at the example level
|
||||
// across invocations of the kernel on different input data.
|
||||
//
|
||||
// TODO(sibyl-Mooth6ku): Replace this in-Kernel data structure with a first class
|
||||
// citizen mutable Dictionary in tensorflow proper, that we will initialize
|
||||
// and update externally.
|
||||
DataByExample* data_by_example = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->resource_manager()->LookupOrCreate<DataByExample>(
|
||||
container_, solver_uuid_, &data_by_example,
|
||||
[this](DataByExample** ret) {
|
||||
*ret = new DataByExample(container_, solver_uuid_);
|
||||
return Status::OK();
|
||||
}));
|
||||
OP_REQUIRES(
|
||||
context, !data_by_example->RefCountIsOne(),
|
||||
errors::Internal("Expected shared-ownership of data_by_example."));
|
||||
|
||||
const Tensor* example_weights_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("example_weights", &example_weights_t));
|
||||
@ -738,16 +711,19 @@ class SdcaSolver : public OpKernel {
|
||||
"number of example weights (%d).",
|
||||
example_labels.size(), num_examples)));
|
||||
|
||||
const Tensor* example_ids_t;
|
||||
OP_REQUIRES_OK(context, context->input("example_ids", &example_ids_t));
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsVector(example_ids_t->shape()),
|
||||
errors::InvalidArgument("example_ids should be a vector."));
|
||||
const auto example_ids = example_ids_t->vec<string>();
|
||||
OP_REQUIRES(context, example_labels.size() == num_examples,
|
||||
errors::InvalidArgument(strings::Printf(
|
||||
"The number of example ids (%ld) should match the number "
|
||||
"of example weights (%d).",
|
||||
example_ids.size(), num_examples)));
|
||||
const Tensor* example_state_data_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("example_state_data", &example_state_data_t));
|
||||
TensorShape expected_example_state_shape({num_examples, 4});
|
||||
OP_REQUIRES(
|
||||
context, example_state_data_t->shape() == expected_example_state_shape,
|
||||
errors::InvalidArgument("Expected shape ",
|
||||
expected_example_state_shape.DebugString(),
|
||||
" for example_state_data, got ",
|
||||
example_state_data_t->shape().DebugString()));
|
||||
|
||||
Tensor mutable_example_state_data_t(*example_state_data_t);
|
||||
auto example_state_data = mutable_example_state_data_t.matrix<float>();
|
||||
|
||||
FeaturesAndWeights features_and_weights;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -757,17 +733,15 @@ class SdcaSolver : public OpKernel {
|
||||
|
||||
for (int i = 0; i < num_inner_iterations_; ++i) {
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
RunTrainStepsForMiniBatch(
|
||||
num_examples, example_ids, example_labels, example_weights,
|
||||
*context->device()->tensorflow_cpu_worker_threads(),
|
||||
regularizations_, *loss_updater_, &features_and_weights,
|
||||
data_by_example));
|
||||
context, RunTrainStepsForMiniBatch(
|
||||
num_examples, example_labels, example_weights,
|
||||
*context->device()->tensorflow_cpu_worker_threads(),
|
||||
regularizations_, *loss_updater_, &features_and_weights,
|
||||
example_state_data));
|
||||
}
|
||||
features_and_weights.AddDeltaWeights();
|
||||
|
||||
// TODO(sibyl-Mooth6ku): Use core::ScopedUnref once it's moved out of internal.
|
||||
data_by_example->Unref();
|
||||
context->set_output(0, mutable_example_state_data_t);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -779,8 +753,6 @@ class SdcaSolver : public OpKernel {
|
||||
int64 num_dense_features_;
|
||||
Regularizations regularizations_;
|
||||
int num_inner_iterations_;
|
||||
string container_;
|
||||
string solver_uuid_;
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver);
|
||||
|
||||
@ -803,72 +775,26 @@ class SdcaShrinkL1 : public OpKernel {
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
|
||||
|
||||
class SdcaTrainingStats : public OpKernel {
|
||||
class SdcaFprint : public OpKernel {
|
||||
public:
|
||||
explicit SdcaTrainingStats(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
|
||||
explicit SdcaFprint(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& input = ctx->input(0);
|
||||
Tensor* out;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
|
||||
|
||||
const auto in_values = input.flat<string>();
|
||||
auto out_values = out->flat<string>();
|
||||
|
||||
for (int64 i = 0; i < in_values.size(); ++i) {
|
||||
const string& s = in_values(i);
|
||||
Fprint128 fprint = Fingerprint128(s);
|
||||
out_values(i) = strings::StrCat(strings::FpToString(fprint.high64), "-",
|
||||
strings::FpToString(fprint.low64));
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
DataByExample* data_by_example = nullptr;
|
||||
OP_REQUIRES_OK(context, context->resource_manager()->Lookup<DataByExample>(
|
||||
container_, solver_uuid_, &data_by_example));
|
||||
OP_REQUIRES(
|
||||
context, !data_by_example->RefCountIsOne(),
|
||||
errors::Internal("Expected shared-ownership of data_by_example."));
|
||||
|
||||
double total_primal_loss = 0;
|
||||
double total_dual_loss = 0;
|
||||
double total_example_weight = 0;
|
||||
OP_REQUIRES_OK(context,
|
||||
data_by_example->Visit([&](const DataByExample::Data& data) {
|
||||
total_primal_loss += data.primal_loss;
|
||||
total_dual_loss += data.dual_loss;
|
||||
total_example_weight += data.example_weight;
|
||||
}));
|
||||
|
||||
// TODO(sibyl-Mooth6ku): Think about the most arithmetically stable way of
|
||||
// computing (dual + primal) loss (if it matters).
|
||||
|
||||
{
|
||||
Tensor* tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("primal_loss", {}, &tensor));
|
||||
tensor->scalar<double>()() = total_primal_loss;
|
||||
}
|
||||
|
||||
{
|
||||
Tensor* tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("dual_loss", {}, &tensor));
|
||||
tensor->scalar<double>()() = total_dual_loss;
|
||||
}
|
||||
|
||||
{
|
||||
OP_REQUIRES(
|
||||
context, total_example_weight > 0,
|
||||
errors::FailedPrecondition(
|
||||
"No examples found or all examples have zero weight. Either the "
|
||||
"optimizer was trained with no instances or perhaps there is a "
|
||||
"bug in the training data."));
|
||||
|
||||
Tensor* tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("example_weights", {}, &tensor));
|
||||
tensor->scalar<double>()() = total_example_weight;
|
||||
}
|
||||
|
||||
// TODO(sibyl-Mooth6ku): Use core::ScopedUnref once it's moved out of internal.
|
||||
data_by_example->Unref();
|
||||
}
|
||||
|
||||
private:
|
||||
string container_;
|
||||
string solver_uuid_;
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("SdcaTrainingStats").Device(DEVICE_CPU),
|
||||
SdcaTrainingStats);
|
||||
REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -25,16 +25,15 @@ REGISTER_OP("SdcaSolver")
|
||||
.Attr("l1: float")
|
||||
.Attr("l2: float")
|
||||
.Attr("num_inner_iterations: int >= 1")
|
||||
.Attr("container: string")
|
||||
.Attr("solver_uuid: string")
|
||||
.Input("sparse_features_indices: num_sparse_features * int64")
|
||||
.Input("sparse_features_values: num_sparse_features * float")
|
||||
.Input("dense_features: num_dense_features * float")
|
||||
.Input("example_weights: float")
|
||||
.Input("example_labels: float")
|
||||
.Input("example_ids: string")
|
||||
.Input("sparse_weights: Ref(num_sparse_features * float)")
|
||||
.Input("dense_weights: Ref(num_dense_features * float)")
|
||||
.Input("example_state_data: float")
|
||||
.Output("example_data_data_out: float")
|
||||
.Doc(R"doc(
|
||||
Stochastic Dual Coordinate Ascent (SDCA) optimizer for linear models with
|
||||
L1 + L2 regularization. As global optimization objective is strongly-convex, the
|
||||
@ -54,9 +53,6 @@ num_dense_features: Number of dense feature groups to train on.
|
||||
l1: Symmetric l1 regularization strength.
|
||||
l2: Symmetric l2 regularization strength.
|
||||
num_inner_iterations: Number of iterations per mini-batch.
|
||||
container: Name of the Container that stores data across invocations of this
|
||||
Kernel. Together with SolverUUID form an isolation unit for this solver.
|
||||
solver_uuid: Universally Unique Identifier for this solver.
|
||||
sparse_features_indices: a list of matrices with two columns that contain
|
||||
example_indices, and feature_indices.
|
||||
sparse_features_values: a list of vectors which contains feature value
|
||||
@ -66,12 +62,13 @@ example_weights: a vector which contains the weight associated with each
|
||||
example.
|
||||
example_labels: a vector which contains the label/target associated with each
|
||||
example.
|
||||
example_ids: a vector which contains the unique identifier associated with each
|
||||
example.
|
||||
sparse_weights: a list of vectors where each value is the weight associated with
|
||||
a feature group.
|
||||
dense_weights: a list of vectors where the value is the weight associated with
|
||||
a dense feature group.
|
||||
example_state_data: a list of vectors containing the example state data.
|
||||
example_data_data_out: a list of vectors containing the updated example state
|
||||
data.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SdcaShrinkL1")
|
||||
@ -94,23 +91,14 @@ dense_weights: a list of vectors where the value is the weight associated with
|
||||
a dense feature group.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SdcaTrainingStats")
|
||||
.Attr("container: string")
|
||||
.Attr("solver_uuid: string")
|
||||
.Output("primal_loss: float64")
|
||||
.Output("dual_loss: float64")
|
||||
.Output("example_weights: float64")
|
||||
REGISTER_OP("SdcaFprint")
|
||||
.Input("input: string")
|
||||
.Output("output: string")
|
||||
.Doc(R"doc(
|
||||
Computes statistics over all examples seen by the optimizer.
|
||||
Computes fingerprints of the input strings.
|
||||
|
||||
container: Name of the Container that stores data across invocations of this
|
||||
Kernel. Together with SolverUUID form an isolation unit for this solver.
|
||||
solver_uuid: Universally Unique Identifier for this solver.
|
||||
primal_loss: total primal loss of all examples seen by the optimizer.
|
||||
dual_loss: total dual loss of all examples seen by the optimizer.
|
||||
example_weights: total example weights of all examples seen by the optimizer
|
||||
(guaranteed to be positive; otherwise returns FAILED_PRECONDITION as it
|
||||
probably indicates a bug in the training data).
|
||||
input: strings to compute fingerprints on.
|
||||
output: the computed fingerprints.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -24,6 +24,7 @@ import uuid
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _sdca_ops
|
||||
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
|
||||
from tensorflow.python.framework.test_util import TensorFlowTestCase
|
||||
from tensorflow.python.platform import googletest
|
||||
@ -138,37 +139,6 @@ class SdcaOptimizerTest(TensorFlowTestCase):
|
||||
intra_op_parallelism_threads=1)
|
||||
return self.test_session(use_gpu=False, config=config)
|
||||
|
||||
# The following tests check that operations raise errors when certain
|
||||
# preconditions on the input data are not satisfied. These errors are raised
|
||||
# regardless of the loss type.
|
||||
def testNoWeightedExamples(self):
|
||||
# Setup test data with 1 positive, and 1 negative example.
|
||||
example_protos = [
|
||||
make_example_proto(
|
||||
{'age': [0],
|
||||
'gender': [0]}, 0),
|
||||
make_example_proto(
|
||||
{'age': [1],
|
||||
'gender': [1]}, 1),
|
||||
]
|
||||
# Zeroed out example weights.
|
||||
example_weights = [0.0, 0.0]
|
||||
with self._single_threaded_test_session():
|
||||
examples = make_example_dict(example_protos, example_weights)
|
||||
variables = make_variable_dict(1, 1)
|
||||
options = dict(symmetric_l2_regularization=1,
|
||||
symmetric_l1_regularization=0,
|
||||
loss_type='logistic_loss')
|
||||
|
||||
lr = SdcaModel(CONTAINER, examples, variables, options)
|
||||
tf.initialize_all_variables().run()
|
||||
self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
|
||||
lr.minimize().run()
|
||||
self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
|
||||
with self.assertRaisesOpError(
|
||||
'No examples found or all examples have zero weight.'):
|
||||
lr.approximate_duality_gap().eval()
|
||||
|
||||
|
||||
class SdcaWithLogisticLossTest(SdcaOptimizerTest):
|
||||
"""SDCA optimizer test class for logistic loss."""
|
||||
@ -815,5 +785,18 @@ class SdcaWithHingeLossTest(SdcaOptimizerTest):
|
||||
self.assertAllClose(0.2, unregularized_loss.eval(), atol=0.02)
|
||||
self.assertAllClose(0.4, regularized_loss.eval(), atol=0.02)
|
||||
|
||||
|
||||
class SdcaFprintTest(TensorFlowTestCase):
|
||||
"""Tests for the SdcaFprint op."""
|
||||
|
||||
def testFprint(self):
|
||||
with self.test_session():
|
||||
in_data = tf.constant(['abc', 'very looooooong string', 'def'])
|
||||
out_data = _sdca_ops.sdca_fprint(in_data)
|
||||
self.assertAllEqual([b'a085f09013029e45-3980b2afd2126c04',
|
||||
b'bc5a254df959f26c-512e479a50910f9f',
|
||||
b'79999cd817a03f12-085f182230e03022'],
|
||||
out_data.eval())
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
@ -17,11 +17,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import uuid
|
||||
|
||||
from six.moves import range # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.linear_optimizer.ops import gen_sdca_ops
|
||||
from tensorflow.contrib.lookup import lookup_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework.load_library import load_op_library
|
||||
@ -106,10 +105,11 @@ class SdcaModel(object):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, container, examples, variables, options):
|
||||
def __init__(self, container, examples, variables, options): # pylint: disable=unused-argument
|
||||
"""Create a new sdca optimizer."""
|
||||
# TODO(andreasst): get rid of obsolete container parameter
|
||||
|
||||
if not container or not examples or not variables or not options:
|
||||
if not examples or not variables or not options:
|
||||
raise ValueError('All arguments must be specified.')
|
||||
|
||||
supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss')
|
||||
@ -136,12 +136,12 @@ class SdcaModel(object):
|
||||
raise ValueError('%s should be non-negative. Found (%f)' %
|
||||
(name, value))
|
||||
|
||||
self._container = container
|
||||
self._examples = examples
|
||||
self._variables = variables
|
||||
self._options = options
|
||||
self._solver_uuid = uuid.uuid4().hex
|
||||
self._create_slots()
|
||||
self._hashtable = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32,
|
||||
[0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
def _symmetric_l2_regularization(self):
|
||||
# Algorithmic requirement (for now) is to have minimal l2 of 1.0
|
||||
@ -264,19 +264,23 @@ class SdcaModel(object):
|
||||
sparse_features_indices.append(convert_to_tensor(sf.indices))
|
||||
sparse_features_values.append(convert_to_tensor(sf.values))
|
||||
|
||||
step_op = _sdca_ops.sdca_solver(
|
||||
example_ids_hashed = _sdca_ops.sdca_fprint(convert_to_tensor(
|
||||
self._examples['example_ids']))
|
||||
example_state_data = self._hashtable.lookup(example_ids_hashed)
|
||||
|
||||
example_state_data_updated = _sdca_ops.sdca_solver(
|
||||
sparse_features_indices,
|
||||
sparse_features_values,
|
||||
self._convert_n_to_tensor(self._examples['dense_features']),
|
||||
convert_to_tensor(self._examples['example_weights']),
|
||||
convert_to_tensor(self._examples['example_labels']),
|
||||
convert_to_tensor(self._examples['example_ids']),
|
||||
self._convert_n_to_tensor(
|
||||
self._slots['unshrinked_sparse_features_weights'],
|
||||
as_ref=True),
|
||||
self._convert_n_to_tensor(
|
||||
self._slots['unshrinked_dense_features_weights'],
|
||||
as_ref=True),
|
||||
example_state_data,
|
||||
l1=self._options['symmetric_l1_regularization'],
|
||||
l2=self._symmetric_l2_regularization(),
|
||||
# TODO(sibyl-Aix6ihai): Provide empirical evidence for this. It is better
|
||||
@ -286,17 +290,17 @@ class SdcaModel(object):
|
||||
# reuse old samples than train on new samples.
|
||||
# See: http://arxiv.org/abs/1602.02136.
|
||||
num_inner_iterations=2,
|
||||
loss_type=self._options['loss_type'],
|
||||
container=self._container,
|
||||
solver_uuid=self._solver_uuid)
|
||||
with ops.control_dependencies([step_op]):
|
||||
assign_ops = []
|
||||
loss_type=self._options['loss_type'])
|
||||
with ops.control_dependencies([example_state_data_updated]):
|
||||
insert_op = self._hashtable.insert(example_ids_hashed,
|
||||
example_state_data_updated)
|
||||
update_ops = [insert_op]
|
||||
for name in ['sparse_features_weights', 'dense_features_weights']:
|
||||
for var, slot_var in zip(self._variables[name],
|
||||
self._slots['unshrinked_' + name]):
|
||||
assign_ops.append(var.assign(slot_var))
|
||||
assign_group = control_flow_ops.group(*assign_ops)
|
||||
with ops.control_dependencies([assign_group]):
|
||||
update_ops.append(var.assign(slot_var))
|
||||
update_group = control_flow_ops.group(*update_ops)
|
||||
with ops.control_dependencies([update_group]):
|
||||
shrink_l1 = _sdca_ops.sdca_shrink_l1(
|
||||
self._convert_n_to_tensor(
|
||||
self._variables['sparse_features_weights'],
|
||||
@ -318,14 +322,17 @@ class SdcaModel(object):
|
||||
An Operation that computes the approximate duality gap over all
|
||||
examples.
|
||||
"""
|
||||
(primal_loss, dual_loss, example_weights) = _sdca_ops.sdca_training_stats(
|
||||
container=self._container,
|
||||
solver_uuid=self._solver_uuid)
|
||||
# Note that example_weights is guaranteed to be positive by
|
||||
# sdca_training_stats so dividing by it is safe.
|
||||
return (primal_loss + dual_loss + math_ops.to_double(self._l1_loss()) +
|
||||
(2.0 * math_ops.to_double(self._l2_loss(
|
||||
self._symmetric_l2_regularization())))) / example_weights
|
||||
_, exported_values = self._hashtable.export()
|
||||
summed_values = math_ops.reduce_sum(exported_values, 0)
|
||||
primal_loss = summed_values[1]
|
||||
dual_loss = summed_values[2]
|
||||
example_weights = summed_values[3]
|
||||
# TODO(andreasst): how should example_weights == 0 be handled?
|
||||
return (
|
||||
primal_loss + dual_loss + math_ops.to_float(self._l1_loss()) +
|
||||
(2.0 *
|
||||
math_ops.to_float(self._l2_loss(self._symmetric_l2_regularization())))
|
||||
) / example_weights
|
||||
|
||||
def unregularized_loss(self, examples):
|
||||
"""Add operations to compute the loss (without the regularization loss).
|
||||
|
@ -32,7 +32,7 @@ py_library(
|
||||
deps = [
|
||||
":gc",
|
||||
":manifest_proto_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
@ -57,7 +57,7 @@ py_library(
|
||||
srcs = ["gc.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -25,13 +25,15 @@ import os
|
||||
import re
|
||||
import six
|
||||
|
||||
import tensorflow as tf
|
||||
from google.protobuf.any_pb2 import Any
|
||||
|
||||
from tensorflow.contrib.session_bundle import gc
|
||||
from tensorflow.contrib.session_bundle import manifest_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
@ -62,20 +64,20 @@ def gfile_copy_callback(files_to_copy, export_dir_path):
|
||||
basename in the export directory.
|
||||
export_dir_path: Directory to copy the files to.
|
||||
"""
|
||||
tf.logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
|
||||
logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
|
||||
gfile.MakeDirs(export_dir_path)
|
||||
for source_filepath, basename in files_to_copy.items():
|
||||
new_path = os.path.join(
|
||||
compat.as_bytes(export_dir_path), compat.as_bytes(basename))
|
||||
tf.logging.info("Copying asset %s to path %s.", source_filepath, new_path)
|
||||
logging.info("Copying asset %s to path %s.", source_filepath, new_path)
|
||||
|
||||
if gfile.Exists(new_path):
|
||||
# Guard against being restarted while copying assets, and the file
|
||||
# existing and being in an unknown state.
|
||||
# TODO(b/28676216): Do some file checks before deleting.
|
||||
tf.logging.info("Removing file %s.", new_path)
|
||||
logging.info("Removing file %s.", new_path)
|
||||
gfile.Remove(new_path)
|
||||
tf.gfile.Copy(source_filepath, new_path)
|
||||
gfile.Copy(source_filepath, new_path)
|
||||
|
||||
|
||||
def regression_signature(input_tensor, output_tensor):
|
||||
@ -188,22 +190,22 @@ class Exporter(object):
|
||||
self._has_init = True
|
||||
|
||||
if graph_def or clear_devices:
|
||||
copy = tf.GraphDef()
|
||||
copy = graph_pb2.GraphDef()
|
||||
if graph_def:
|
||||
copy.CopyFrom(graph_def)
|
||||
else:
|
||||
copy.CopyFrom(tf.get_default_graph().as_graph_def())
|
||||
copy.CopyFrom(ops.get_default_graph().as_graph_def())
|
||||
if clear_devices:
|
||||
for node in copy.node:
|
||||
node.device = ""
|
||||
graph_any_buf = Any()
|
||||
graph_any_buf.Pack(copy)
|
||||
tf.add_to_collection(GRAPH_KEY, graph_any_buf)
|
||||
ops.add_to_collection(GRAPH_KEY, graph_any_buf)
|
||||
|
||||
if init_op:
|
||||
if not isinstance(init_op, ops.Operation):
|
||||
raise TypeError("init_op needs to be an Operation: %s" % init_op)
|
||||
tf.add_to_collection(INIT_OP_KEY, init_op)
|
||||
ops.add_to_collection(INIT_OP_KEY, init_op)
|
||||
|
||||
signatures_proto = manifest_pb2.Signatures()
|
||||
if default_graph_signature:
|
||||
@ -212,7 +214,7 @@ class Exporter(object):
|
||||
signatures_proto.named_signatures[signature_name].CopyFrom(signature)
|
||||
signatures_any_buf = Any()
|
||||
signatures_any_buf.Pack(signatures_proto)
|
||||
tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
|
||||
ops.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
|
||||
|
||||
for filename, tensor in assets:
|
||||
asset = manifest_pb2.AssetFile()
|
||||
@ -220,7 +222,7 @@ class Exporter(object):
|
||||
asset.tensor_binding.tensor_name = tensor.name
|
||||
asset_any_buf = Any()
|
||||
asset_any_buf.Pack(asset)
|
||||
tf.add_to_collection(ASSETS_KEY, asset_any_buf)
|
||||
ops.add_to_collection(ASSETS_KEY, asset_any_buf)
|
||||
|
||||
self._assets_callback = assets_callback
|
||||
|
||||
@ -250,6 +252,10 @@ class Exporter(object):
|
||||
if not self._has_init:
|
||||
raise RuntimeError("init must be called first")
|
||||
|
||||
# Export dir must not end with / or it will break exports to keep. Strip /.
|
||||
if export_dir_base.endswith("/"):
|
||||
export_dir_base = export_dir_base[:-1]
|
||||
|
||||
global_step = training_util.global_step(sess, global_step_tensor)
|
||||
export_dir = os.path.join(
|
||||
compat.as_bytes(export_dir_base),
|
||||
@ -299,11 +305,11 @@ class Exporter(object):
|
||||
|
||||
def _file_path_value(self, path_tensor):
|
||||
"""Returns the filepath value stored in constant `path_tensor`."""
|
||||
if not isinstance(path_tensor, tf.Tensor):
|
||||
if not isinstance(path_tensor, ops.Tensor):
|
||||
raise TypeError("tensor is not a Tensor")
|
||||
if path_tensor.op.type != "Const":
|
||||
raise TypeError("Only constants tensor are supported")
|
||||
if path_tensor.dtype != tf.string:
|
||||
if path_tensor.dtype != dtypes.string:
|
||||
raise TypeError("File paths should be string")
|
||||
str_value = path_tensor.op.get_attr("value").string_val
|
||||
if len(str_value) != 1:
|
||||
|
@ -1127,23 +1127,6 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "cupti_wrapper_default",
|
||||
srcs = [
|
||||
"platform/default/gpu/cupti_wrapper.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"platform/default/gpu/cupti_wrapper.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
cuda_deps = [
|
||||
":stream_executor",
|
||||
"//third_party/gpus/cuda:cuda_headers",
|
||||
"//third_party/gpus/cuda:cupti_headers",
|
||||
],
|
||||
data = ["//third_party/gpus/cuda:cupti_dsos"],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "gpu_tracer",
|
||||
srcs = [
|
||||
|
@ -184,27 +184,27 @@ class SimpleRendezvous : public Rendezvous {
|
||||
public:
|
||||
explicit SimpleRendezvous() {}
|
||||
|
||||
Status Send(const string& key, const Args& send_args, const Tensor& val,
|
||||
Status Send(const ParsedKey& parsed, const Args& send_args, const Tensor& val,
|
||||
const bool is_dead) override {
|
||||
if (is_dead) {
|
||||
return errors::Internal("Send of a dead tensor");
|
||||
}
|
||||
ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(ParseKey(key, &parsed));
|
||||
|
||||
mutex_lock l(mu_);
|
||||
if (table_.count(parsed.edge_name) > 0) {
|
||||
string edge_name = parsed.edge_name.ToString();
|
||||
if (table_.count(edge_name) > 0) {
|
||||
return errors::Internal("Send of an already sent tensor");
|
||||
}
|
||||
table_[parsed.edge_name] = val;
|
||||
table_[edge_name] = val;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecvAsync(const string& key, const Args& recv_args,
|
||||
void RecvAsync(const ParsedKey& parsed, const Args& recv_args,
|
||||
DoneCallback done) override {
|
||||
Tensor tensor;
|
||||
Status status = Status::OK();
|
||||
{
|
||||
string key = parsed.edge_name.ToString();
|
||||
mutex_lock l(mu_);
|
||||
if (table_.count(key) <= 0) {
|
||||
status = errors::Internal("Did not find key ", key);
|
||||
@ -417,7 +417,14 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
||||
// this node. Don't bother processing the rest of the nodes.
|
||||
return c > 0;
|
||||
}
|
||||
Status s = rendez->Recv(tensor_name, Rendezvous::Args(), &output, &is_dead);
|
||||
|
||||
string full_key = Rendezvous::CreateKey("/cpu:0", 1, "/cpu:1", tensor_name,
|
||||
FrameAndIter(0, 0));
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = Rendezvous::ParseKey(full_key, &parsed);
|
||||
if (s.ok()) {
|
||||
s = rendez->Recv(parsed, Rendezvous::Args(), &output, &is_dead);
|
||||
}
|
||||
if (!s.ok() || is_dead) {
|
||||
return c > 0;
|
||||
}
|
||||
|
@ -43,8 +43,7 @@ std::vector<RegistrationInfo>* MutableRegistry() {
|
||||
} // namespace
|
||||
|
||||
// static
|
||||
void CopyTensor::ViaDMA(const string& edge_name,
|
||||
DeviceContext* send_dev_context,
|
||||
void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
||||
DeviceContext* recv_dev_context, Device* src,
|
||||
Device* dst, const AllocatorAttributes src_alloc_attr,
|
||||
const AllocatorAttributes dst_alloc_attr,
|
||||
|
@ -42,7 +42,7 @@ class CopyTensor {
|
||||
// the type of devices and memory in use, the copy may be performed
|
||||
// synchronously or asynchronously. 'done' will be invoked only
|
||||
// after the copy is actually complete.
|
||||
static void ViaDMA(const string& edge_name, DeviceContext* send_dev_context,
|
||||
static void ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
||||
DeviceContext* recv_dev_context, Device* src, Device* dst,
|
||||
const AllocatorAttributes src_alloc_attr,
|
||||
const AllocatorAttributes dst_alloc_attr,
|
||||
|
@ -24,13 +24,17 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) {
|
||||
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
|
||||
: name_backing_store_(128) {
|
||||
for (Device* d : devices) {
|
||||
devices_.push_back(d);
|
||||
|
||||
// Register under both the full name and the local name.
|
||||
device_map_[d->name()] = d;
|
||||
device_map_[DeviceNameUtils::LocalName(d->name())] = d;
|
||||
string full_name = d->name();
|
||||
device_map_[CopyToBackingStore(full_name)] = d;
|
||||
|
||||
string lname = DeviceNameUtils::LocalName(d->name());
|
||||
device_map_[CopyToBackingStore(lname)] = d;
|
||||
device_type_counts_[d->device_type()]++;
|
||||
}
|
||||
}
|
||||
@ -39,6 +43,13 @@ DeviceMgr::~DeviceMgr() {
|
||||
for (auto p : devices_) delete p;
|
||||
}
|
||||
|
||||
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
||||
int n = s.size();
|
||||
char* space = name_backing_store_.Alloc(n);
|
||||
memcpy(space, s.data(), n);
|
||||
return StringPiece(space, n);
|
||||
}
|
||||
|
||||
void DeviceMgr::ListDeviceAttributes(
|
||||
std::vector<DeviceAttributes>* devices) const {
|
||||
devices->reserve(devices_.size());
|
||||
@ -70,7 +81,7 @@ string DeviceMgr::DeviceMappingString() const {
|
||||
return out;
|
||||
}
|
||||
|
||||
Status DeviceMgr::LookupDevice(const string& name, Device** device) const {
|
||||
Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
||||
Status s;
|
||||
auto iter = device_map_.find(name);
|
||||
if (iter == device_map_.end()) {
|
||||
|
@ -22,7 +22,9 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/lib/core/arena.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
@ -49,7 +51,7 @@ class DeviceMgr {
|
||||
|
||||
// Assigns *device with pointer to Device of the given name.
|
||||
// Accepts either a full device name, or just the replica-local suffix.
|
||||
Status LookupDevice(const string& name, Device** device) const;
|
||||
Status LookupDevice(StringPiece name, Device** device) const;
|
||||
|
||||
// Clears given containers of all devices if 'container' is
|
||||
// non-empty. Otherwise, clears default containers of all devices.
|
||||
@ -60,7 +62,11 @@ class DeviceMgr {
|
||||
private:
|
||||
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
|
||||
DeviceVec devices_;
|
||||
std::unordered_map<string, Device*> device_map_;
|
||||
|
||||
StringPiece CopyToBackingStore(StringPiece s);
|
||||
|
||||
std::unordered_map<StringPiece, Device*, StringPiece::Hasher> device_map_;
|
||||
core::Arena name_backing_store_; // Storage for keys in device_map_
|
||||
std::unordered_map<string, int> device_type_counts_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
|
||||
|
@ -551,6 +551,7 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
|
||||
const ExecutorsAndKeys* executors_and_keys,
|
||||
IntraProcessRendezvous* rendez) {
|
||||
Status s;
|
||||
Rendezvous::ParsedKey parsed;
|
||||
// Insert the input tensors into the local rendezvous by their
|
||||
// rendezvous key.
|
||||
for (const auto& input : inputs) {
|
||||
@ -560,7 +561,14 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
|
||||
"' is not a pre-defined feed!");
|
||||
}
|
||||
const string& input_key = it->second;
|
||||
s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
|
||||
|
||||
s = Rendezvous::ParseKey(input_key, &parsed);
|
||||
if (!s.ok()) {
|
||||
rendez->StartAbort(s);
|
||||
return s;
|
||||
}
|
||||
|
||||
s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
|
||||
if (!s.ok()) {
|
||||
rendez->StartAbort(s);
|
||||
return s;
|
||||
@ -578,6 +586,7 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
|
||||
outputs->resize(output_names.size());
|
||||
}
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
// Get the outputs from the rendezvous
|
||||
for (size_t output_offset = 0; output_offset < output_names.size();
|
||||
++output_offset) {
|
||||
@ -591,14 +600,16 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
|
||||
const string& output_key = it->second;
|
||||
Tensor output_tensor;
|
||||
bool is_dead;
|
||||
|
||||
// Fetch data from the Rendezvous.
|
||||
IntraProcessRendezvous* rendez = run_state->rendez;
|
||||
s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead);
|
||||
if (is_dead && s.ok()) {
|
||||
s = errors::InvalidArgument("The tensor returned for ",
|
||||
output_names[output_offset],
|
||||
" was not valid.");
|
||||
|
||||
s = Rendezvous::ParseKey(output_key, &parsed);
|
||||
if (s.ok()) {
|
||||
// Fetch data from the Rendezvous.
|
||||
s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead);
|
||||
if (is_dead && s.ok()) {
|
||||
s = errors::InvalidArgument("The tensor returned for ", output_name,
|
||||
" was not valid.");
|
||||
}
|
||||
}
|
||||
if (!s.ok()) {
|
||||
rendez->StartAbort(s);
|
||||
|
@ -27,6 +27,8 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
|
||||
~EigenThreadPoolWrapper() override {}
|
||||
|
||||
void Schedule(std::function<void()> fn) override { pool_->Schedule(fn); }
|
||||
int NumThreads() const override { return pool_->NumThreads(); }
|
||||
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
|
||||
|
||||
private:
|
||||
thread::ThreadPool* pool_ = nullptr;
|
||||
|
@ -645,6 +645,43 @@ class ExecutorState {
|
||||
}
|
||||
};
|
||||
|
||||
// A drop-in replacement for std::deque<TaggedNode>. We typically don't
|
||||
// have that many nodes in the ready queue, so we just use a vector and
|
||||
// don't free up memory from the queue as we consume nodes.
|
||||
class TaggedNodeReadyQueue {
|
||||
public:
|
||||
TaggedNodeReadyQueue() : front_index_(0) {}
|
||||
|
||||
void push_back(TaggedNode node) { ready_.push_back(node); }
|
||||
TaggedNode front() const {
|
||||
DCHECK_LT(front_index_, ready_.size());
|
||||
return ready_[front_index_];
|
||||
}
|
||||
void pop_front() {
|
||||
DCHECK_LT(front_index_, ready_.size());
|
||||
front_index_++;
|
||||
if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
|
||||
if (front_index_ == ready_.size()) {
|
||||
ready_.clear();
|
||||
} else {
|
||||
// Lots of unused entries at beginning of vector: move everything down
|
||||
// to start of vector.
|
||||
ready_.erase(ready_.begin(), ready_.begin() + front_index_);
|
||||
}
|
||||
front_index_ = 0;
|
||||
}
|
||||
}
|
||||
bool empty() const { return ready_.empty(); }
|
||||
const TaggedNode* begin() const { return ready_.begin() + front_index_; }
|
||||
const TaggedNode* end() const { return ready_.end(); }
|
||||
|
||||
private:
|
||||
gtl::InlinedVector<TaggedNode, 16> ready_;
|
||||
int front_index_;
|
||||
};
|
||||
|
||||
struct AsyncState;
|
||||
|
||||
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
|
||||
typedef gtl::InlinedVector<Entry, 4> EntryVector;
|
||||
|
||||
@ -767,15 +804,15 @@ class ExecutorState {
|
||||
// "node" just finishes. Takes ownership of "stats". Returns true if
|
||||
// execution has completed.
|
||||
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
|
||||
NodeExecStats* stats, std::deque<TaggedNode>* inline_ready);
|
||||
NodeExecStats* stats, TaggedNodeReadyQueue* inline_ready);
|
||||
|
||||
// Call Process() on all nodes in 'inline_ready'.
|
||||
void ProcessInline(const std::deque<TaggedNode>& inline_ready);
|
||||
void ProcessInline(const TaggedNodeReadyQueue& inline_ready);
|
||||
|
||||
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
|
||||
// nodes in 'ready' into 'inline_ready'.
|
||||
void ScheduleReady(const TaggedNodeSeq& ready,
|
||||
std::deque<TaggedNode>* inline_ready);
|
||||
TaggedNodeReadyQueue* inline_ready);
|
||||
|
||||
// Provide debugging output about an outstanding node in the executor.
|
||||
void DumpCompletedNodeState(const int node_id, const Entry* input_vector);
|
||||
@ -905,43 +942,55 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
// State kept alive for executing an asynchronous node in another
|
||||
// thread. NOTE: We need to make a copy of p.input,
|
||||
// p.input_device_contexts, and p.input_alloc_attrs for asynchronous
|
||||
// kernels because OpKernelContext methods like input_type(i) need
|
||||
// the params to point to a valid input type vector. It's not an issue for
|
||||
// sync kernels because these vectors are kept on the stack.
|
||||
struct ExecutorState::AsyncState {
|
||||
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
|
||||
const NodeItem& _item, Entry* _first_input, NodeExecStats* _stats)
|
||||
: saved_inputs(*p.inputs),
|
||||
saved_input_device_contexts(*p.input_device_contexts),
|
||||
saved_input_alloc_attrs(*p.input_alloc_attrs),
|
||||
params(p),
|
||||
tagged_node(_tagged_node),
|
||||
item(_item),
|
||||
first_input(_first_input),
|
||||
// ParamsButClearingEigenGPUDevice does equivalent of
|
||||
// params.eigen_gpu_device = nullptr;
|
||||
ctx(ParamsButClearingEigenGPUDevice(¶ms), item.num_outputs),
|
||||
stats(_stats) {
|
||||
params.inputs = &saved_inputs;
|
||||
params.input_device_contexts = &saved_input_device_contexts;
|
||||
params.input_alloc_attrs = &saved_input_alloc_attrs;
|
||||
}
|
||||
|
||||
// Helpers to make a copy of 'p' and makes a copy of the input type
|
||||
// vector and the device context vector.
|
||||
//
|
||||
// NOTE: We need to make a copy of p.input for asynchronous kernel
|
||||
// because OpKernelContext methods like input_type(i) needs the param
|
||||
// points to valid input type vector. It's not an issue for sync
|
||||
// kernels because the type vector is kept on the stack.
|
||||
OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) {
|
||||
OpKernelContext::Params* ret = new OpKernelContext::Params;
|
||||
*ret = p;
|
||||
// Ensure the copy of Params will make a new eigen GPU device if
|
||||
// necessary.
|
||||
ret->eigen_gpu_device = nullptr;
|
||||
ret->inputs = new TensorValueVec(*p.inputs);
|
||||
ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts);
|
||||
ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs);
|
||||
return ret;
|
||||
}
|
||||
TensorValueVec saved_inputs;
|
||||
DeviceContextVec saved_input_device_contexts;
|
||||
AllocatorAttributeVec saved_input_alloc_attrs;
|
||||
OpKernelContext::Params params;
|
||||
TaggedNode tagged_node;
|
||||
NodeItem item;
|
||||
Entry* first_input;
|
||||
OpKernelContext ctx;
|
||||
NodeExecStats* stats;
|
||||
|
||||
// Helpers to delete 'p' and copies made by CopyParams.
|
||||
void DeleteParams(OpKernelContext::Params* p) {
|
||||
// No need to delete p->eigen_gpu_device since that is deleted in
|
||||
// p's destructor
|
||||
delete p->inputs;
|
||||
delete p->input_device_contexts;
|
||||
delete p->input_alloc_attrs;
|
||||
delete p;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
private:
|
||||
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
|
||||
OpKernelContext::Params* p) {
|
||||
// Ensure OpKernelContext constructor will make a new eigen GPU device if
|
||||
// necessary.
|
||||
p->eigen_gpu_device = nullptr; // Force allocation
|
||||
return p;
|
||||
}
|
||||
};
|
||||
|
||||
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
|
||||
const NodeItem* nodes = impl_->nodes_;
|
||||
TaggedNodeSeq ready;
|
||||
std::deque<TaggedNode> inline_ready;
|
||||
TaggedNodeReadyQueue inline_ready;
|
||||
|
||||
// Parameters passed to OpKernel::Compute.
|
||||
TensorValueVec inputs;
|
||||
@ -1059,20 +1108,25 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
|
||||
AsyncOpKernel* async = item.kernel->AsAsync();
|
||||
DCHECK(async != nullptr);
|
||||
launched_asynchronously = true;
|
||||
auto pcopy = CopyParams(params);
|
||||
auto ctx = new OpKernelContext(pcopy, item.num_outputs);
|
||||
auto done = [this, tagged_node, item, first_input, ctx, stats, pcopy,
|
||||
device]() {
|
||||
AsyncState* state =
|
||||
new AsyncState(params, tagged_node, item, first_input, stats);
|
||||
|
||||
auto done = [this, state]() {
|
||||
|
||||
Device* device = impl_->params_.device;
|
||||
NodeExecStats* stats = state->stats; // Shorthand
|
||||
Entry* first_input = state->first_input; // Shorthand
|
||||
|
||||
if (vlog_) {
|
||||
VLOG(2) << this << " Async kernel done: "
|
||||
<< SummarizeNodeDef(item.node->def());
|
||||
<< SummarizeNodeDef(state->item.node->def());
|
||||
}
|
||||
if (stats_collector_) nodestats::SetOpEnd(stats);
|
||||
EntryVector outputs;
|
||||
Status s = ProcessOutputs(item, ctx, &outputs, stats);
|
||||
if (stats_collector_) nodestats::SetMemory(stats, ctx);
|
||||
Status s = ProcessOutputs(state->item, &state->ctx, &outputs, stats);
|
||||
if (stats_collector_) nodestats::SetMemory(stats, &state->ctx);
|
||||
// Clears inputs.
|
||||
int num_inputs = item.num_inputs;
|
||||
const int num_inputs = state->item.num_inputs;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
(first_input + i)->val = *kEmptyTensor;
|
||||
}
|
||||
@ -1080,31 +1134,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
|
||||
// add better optional debugging support.
|
||||
if (vlog_ && VLOG_IS_ON(1)) {
|
||||
mutex_lock l(mu_);
|
||||
tagged_node.input_frame->GetIteration(tagged_node.input_iter)
|
||||
->mark_completed(tagged_node.node->id());
|
||||
state->tagged_node.input_frame
|
||||
->GetIteration(state->tagged_node.input_iter)
|
||||
->mark_completed(state->tagged_node.node->id());
|
||||
}
|
||||
TaggedNodeSeq ready;
|
||||
if (s.ok()) {
|
||||
PropagateOutputs(tagged_node, outputs, &ready);
|
||||
PropagateOutputs(state->tagged_node, outputs, &ready);
|
||||
}
|
||||
outputs.clear();
|
||||
if (s.ok() && pcopy->device->RequiresRecordingAccessedTensors()) {
|
||||
if (s.ok() &&
|
||||
state->params.device->RequiresRecordingAccessedTensors()) {
|
||||
// Get the list of all tensors accessed during the execution
|
||||
TensorReferenceVector accessed;
|
||||
ctx->retrieve_accessed_tensors(&accessed);
|
||||
state->ctx.retrieve_accessed_tensors(&accessed);
|
||||
if (stats_collector_)
|
||||
nodestats::SetReferencedTensors(stats, accessed);
|
||||
// callee takes ownership of the vector
|
||||
device->ConsumeListOfAccessedTensors(ctx->op_device_context(),
|
||||
device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
|
||||
accessed);
|
||||
}
|
||||
bool completed = NodeDone(s, item.node, ready, stats, nullptr);
|
||||
delete ctx;
|
||||
DeleteParams(pcopy);
|
||||
bool completed = NodeDone(s, state->item.node, ready, stats, nullptr);
|
||||
delete state;
|
||||
if (completed) Finish();
|
||||
};
|
||||
if (stats_collector_) nodestats::SetOpStart(stats);
|
||||
device->ComputeAsync(async, ctx, done);
|
||||
device->ComputeAsync(async, &state->ctx, done);
|
||||
} else {
|
||||
// Synchronous computes.
|
||||
OpKernelContext ctx(¶ms, item.num_outputs);
|
||||
@ -1497,7 +1552,7 @@ void ExecutorState::AddLoopInv(FrameState* frame, const Node* node,
|
||||
|
||||
bool ExecutorState::NodeDone(const Status& s, const Node* node,
|
||||
const TaggedNodeSeq& ready, NodeExecStats* stats,
|
||||
std::deque<TaggedNode>* inline_ready) {
|
||||
TaggedNodeReadyQueue* inline_ready) {
|
||||
if (stats_collector_) {
|
||||
nodestats::SetAllEnd(stats);
|
||||
stats_collector_->UpdateCostModelNode(stats, impl_->graph_, node);
|
||||
@ -1542,7 +1597,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
|
||||
return completed;
|
||||
}
|
||||
|
||||
void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
|
||||
void ExecutorState::ProcessInline(const TaggedNodeReadyQueue& inline_ready) {
|
||||
if (inline_ready.empty()) return;
|
||||
int64 scheduled_usec = 0;
|
||||
if (stats_collector_) {
|
||||
@ -1554,7 +1609,7 @@ void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
|
||||
}
|
||||
|
||||
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
|
||||
std::deque<TaggedNode>* inline_ready) {
|
||||
TaggedNodeReadyQueue* inline_ready) {
|
||||
if (ready.empty()) return;
|
||||
|
||||
int64 scheduled_usec = 0;
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
|
||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -30,7 +30,7 @@ void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
}
|
||||
|
||||
void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
const string& tensor_name,
|
||||
StringPiece tensor_name,
|
||||
Device* device, Tensor* cpu_tensor,
|
||||
StatusCallback done) {
|
||||
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
|
||||
|
@ -56,9 +56,9 @@ class GPUDeviceContext : public DeviceContext {
|
||||
Tensor* device_tensor,
|
||||
StatusCallback done) const override;
|
||||
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
const string& edge_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) override;
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
|
||||
Device* device, Tensor* cpu_tensor,
|
||||
StatusCallback done) override;
|
||||
|
||||
void MaintainLifetimeOnStream(
|
||||
const Tensor* t, perftools::gputools::Stream* stream) const override {}
|
||||
|
@ -39,7 +39,6 @@ namespace test {
|
||||
|
||||
Benchmark::Benchmark(const string& device, Graph* g,
|
||||
const SessionOptions* options, Graph* init) {
|
||||
|
||||
SessionOptions default_options;
|
||||
if (!options) {
|
||||
options = &default_options;
|
||||
@ -136,25 +135,35 @@ void Benchmark::RunWithArgs(
|
||||
args.runner = [this](std::function<void()> closure) {
|
||||
pool_->Schedule(closure);
|
||||
};
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
static const int kWarmupRuns = 3;
|
||||
for (int i = 0; i < kWarmupRuns; ++i) {
|
||||
for (const auto& p : in) {
|
||||
rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
|
||||
rendez_->Send(parsed, Rendezvous::Args(), p.second, false);
|
||||
}
|
||||
TF_CHECK_OK(exec_->Run(args));
|
||||
for (const string& key : out) {
|
||||
rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
|
||||
rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead);
|
||||
}
|
||||
}
|
||||
TF_CHECK_OK(device_->Sync());
|
||||
VLOG(3) << kWarmupRuns << " warmup runs done.";
|
||||
|
||||
testing::StartTiming();
|
||||
while (iters-- > 0) {
|
||||
for (const auto& p : in) {
|
||||
rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
|
||||
rendez_->Send(parsed, Rendezvous::Args(), p.second, false);
|
||||
}
|
||||
TF_CHECK_OK(exec_->Run(args));
|
||||
for (const string& key : out) {
|
||||
rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
|
||||
rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,19 +36,17 @@ IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
|
||||
|
||||
IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); }
|
||||
|
||||
Status IntraProcessRendezvous::Send(const string& key,
|
||||
Status IntraProcessRendezvous::Send(const ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) {
|
||||
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key;
|
||||
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) return status_;
|
||||
}
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
|
||||
// Buffers "val" and "device_context" in local_.
|
||||
return local_->Send(key, args, val, is_dead);
|
||||
return local_->Send(parsed, args, val, is_dead);
|
||||
}
|
||||
|
||||
Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src,
|
||||
@ -111,24 +109,17 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
|
||||
done);
|
||||
}
|
||||
|
||||
void IntraProcessRendezvous::RecvAsync(const string& key,
|
||||
void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = ParseKey(key, false /*!is_src*/, &parsed);
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor(), false);
|
||||
return;
|
||||
}
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey();
|
||||
|
||||
// Recv the tensor from local_.
|
||||
local_->RecvAsync(key, recv_args, [this, parsed, done](
|
||||
const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
local_->RecvAsync(parsed, recv_args, [this, parsed, done](
|
||||
const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
Status s = status;
|
||||
Tensor* out = new Tensor;
|
||||
StatusCallback final_callback = [done, send_args, recv_args, out,
|
||||
|
@ -43,14 +43,14 @@ class IntraProcessRendezvous : public Rendezvous {
|
||||
|
||||
// Forwards to local_, where the Tensor "val" will be buffered and
|
||||
// any waiting callback stored.
|
||||
Status Send(const string& key, const Rendezvous::Args& args,
|
||||
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
|
||||
// This method is called only by the RecvOp. It tests to see
|
||||
// whether the value will be produced by a local or remote device
|
||||
// and handles accordingly. In the local case it forwards to
|
||||
// local_, in the remote case it initiates an RPC request.
|
||||
void RecvAsync(const string& key, const Rendezvous::Args& args,
|
||||
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
void StartAbort(const Status& status) override;
|
||||
|
@ -60,23 +60,25 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key,
|
||||
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
Rendezvous::DoneCallback done) {
|
||||
BaseRemoteRendezvous* rendez = FindOrCreate(step_id);
|
||||
rendez->RecvLocalAsync(
|
||||
key, [rendez, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& v,
|
||||
bool dead) {
|
||||
parsed, [rendez, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& v,
|
||||
bool dead) {
|
||||
rendez->Unref();
|
||||
done(s, send_args, recv_args, v, dead);
|
||||
});
|
||||
}
|
||||
|
||||
Status BaseRendezvousMgr::RecvLocal(int64 step_id, const string& key,
|
||||
Status BaseRendezvousMgr::RecvLocal(int64 step_id,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
Tensor* val, bool* is_dead) {
|
||||
Status ret;
|
||||
Notification n;
|
||||
RecvLocalAsync(step_id, key,
|
||||
RecvLocalAsync(step_id, parsed,
|
||||
[val, is_dead, &ret, &n](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
@ -140,38 +142,35 @@ static bool IsLocalDevice(const WorkerEnv& worker,
|
||||
return device_name.starts_with(worker.worker_name);
|
||||
}
|
||||
|
||||
Status BaseRemoteRendezvous::Send(const string& key,
|
||||
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) {
|
||||
VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << key;
|
||||
VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) return status_;
|
||||
}
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
if (!IsLocalDevice(*env_, parsed.src_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ",
|
||||
parsed.FullKey(), " @ ", env_->worker_name);
|
||||
}
|
||||
// Buffers "val" and "device_context" in local_.
|
||||
return local_->Send(key, args, val, is_dead);
|
||||
return local_->Send(parsed, args, val, is_dead);
|
||||
}
|
||||
|
||||
Status BaseRemoteRendezvous::ParseKey(const string& key, bool is_src,
|
||||
Rendezvous::ParsedKey* parsed) {
|
||||
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
|
||||
bool is_src) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) return status_;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
|
||||
if (is_src && !IsLocalDevice(*env_, parsed->src_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
if (is_src && !IsLocalDevice(*env_, parsed.src_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ",
|
||||
parsed.FullKey(), " @ ", env_->worker_name);
|
||||
}
|
||||
if (!is_src && !IsLocalDevice(*env_, parsed->dst_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (dst): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
if (!is_src && !IsLocalDevice(*env_, parsed.dst_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
|
||||
parsed.FullKey(), " @ ", env_->worker_name);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -233,13 +232,11 @@ bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
|
||||
return DeviceNameUtils::IsSameAddressSpace(src, dst);
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::RecvAsync(const string& key,
|
||||
void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "RemoteRendezvous Recv " << this << " " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = ParseKey(key, false /*!is_src*/, &parsed);
|
||||
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
|
||||
Status s = ValidateDevices(parsed, false /*!is_src*/);
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor(), false);
|
||||
return;
|
||||
@ -247,12 +244,13 @@ void BaseRemoteRendezvous::RecvAsync(const string& key,
|
||||
|
||||
// Are src and dst in the same worker?
|
||||
if (IsSameWorker(parsed.src, parsed.dst)) {
|
||||
Rendezvous::ParsedKey parsed_copy = parsed;
|
||||
// Recv the tensor from local_.
|
||||
local_->RecvAsync(
|
||||
key, recv_args, [this, parsed, done](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
parsed_copy, recv_args,
|
||||
[this, parsed_copy, done](
|
||||
const Status& status, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
|
||||
Status s = status;
|
||||
Tensor* out = new Tensor;
|
||||
StatusCallback final_callback = [done, send_args, recv_args, out,
|
||||
@ -262,27 +260,26 @@ void BaseRemoteRendezvous::RecvAsync(const string& key,
|
||||
};
|
||||
|
||||
if (s.ok()) {
|
||||
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
|
||||
final_callback);
|
||||
SameWorkerRecvDone(parsed_copy, send_args, recv_args, in, out,
|
||||
std::move(final_callback));
|
||||
} else {
|
||||
final_callback(s);
|
||||
}
|
||||
});
|
||||
return;
|
||||
} else {
|
||||
RecvFromRemoteAsync(key, parsed, recv_args, done);
|
||||
RecvFromRemoteAsync(parsed, recv_args, std::move(done));
|
||||
}
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::RecvLocalAsync(const string& key,
|
||||
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
|
||||
DoneCallback done) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = ParseKey(key, true /* is_src */, &parsed);
|
||||
Status s = ValidateDevices(parsed, true /* is_src */);
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), Args(), Tensor(), false);
|
||||
return;
|
||||
}
|
||||
local_->RecvAsync(key, Args(), done);
|
||||
local_->RecvAsync(parsed, Args(), std::move(done));
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::StartAbort(const Status& s) {
|
||||
|
@ -68,12 +68,12 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
|
||||
// "done" when the tensor for "key" is produced or an error occurs.
|
||||
//
|
||||
// This method is used by the rpc handler of RecvTensor.
|
||||
void RecvLocalAsync(int64 step_id, const string& key,
|
||||
void RecvLocalAsync(int64 step_id, const Rendezvous::ParsedKey& parsed,
|
||||
Rendezvous::DoneCallback done) override;
|
||||
|
||||
// Synchronous wrapper for RecvLocalAsync.
|
||||
Status RecvLocal(int64 step_id, const string& key, Tensor* val,
|
||||
bool* is_dead) override;
|
||||
Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
|
||||
Tensor* val, bool* is_dead) override;
|
||||
|
||||
// Removes rendezvous for "step_id".
|
||||
//
|
||||
@ -116,14 +116,14 @@ class BaseRemoteRendezvous : public Rendezvous {
|
||||
|
||||
// Forwards to local_, where the Tensor "val" will be buffered and
|
||||
// any waiting callback stored.
|
||||
Status Send(const string& key, const Rendezvous::Args& args,
|
||||
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
|
||||
// This method is called only by the RecvOp. It tests to see
|
||||
// whether the value will be produced by a local or remote device
|
||||
// and handles accordingly. In the local case it forwards to
|
||||
// local_, in the remote case it initiates an RPC request.
|
||||
void RecvAsync(const string& key, const Rendezvous::Args& args,
|
||||
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
void StartAbort(const Status& status) override;
|
||||
@ -134,15 +134,14 @@ class BaseRemoteRendezvous : public Rendezvous {
|
||||
// network. In either case it needs to retrieve a locally buffered
|
||||
// value from local_, and give it to its caller.
|
||||
//
|
||||
// Runs "done" as soon as the tensor for "key" is available or an error
|
||||
// Runs "done" as soon as the tensor for "parsed" is available or an error
|
||||
// is detected.
|
||||
//
|
||||
// REQUIRES: "key" is one that will be Saved into the local rendezvous.
|
||||
void RecvLocalAsync(const string& key, DoneCallback done);
|
||||
// REQUIRES: "parsed" is one that will be Saved into the local rendezvous.
|
||||
void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);
|
||||
|
||||
protected:
|
||||
virtual void RecvFromRemoteAsync(const string& key,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
@ -174,11 +173,10 @@ class BaseRemoteRendezvous : public Rendezvous {
|
||||
// Active outstanding RecvTensor calls.
|
||||
std::unordered_set<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
|
||||
|
||||
// Parses "key" into "parsed". If "is_src" is true, checks that the
|
||||
// rendezvous key's source is in this process. If "is_src" is false,
|
||||
// checks that the rendezvous key's destination is in this process.
|
||||
Status ParseKey(const string& key, bool is_src,
|
||||
Rendezvous::ParsedKey* parsed);
|
||||
// If "is_src" is true, checks that the rendezvous key "parsed"'s
|
||||
// source is in this process. If "is_src" is false, checks that the
|
||||
// rendezvous key "parsed"'s destination is in this process.
|
||||
Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);
|
||||
|
||||
// Callback handling the case when a rendezvous has been
|
||||
// accomplished in local_ and the consumer is local to this process.
|
||||
|
@ -33,7 +33,7 @@ void CallOptions::StartCancel() {
|
||||
|
||||
void CallOptions::SetCancelCallback(CancelFunction cancel_func) {
|
||||
mutex_lock l(mu_);
|
||||
cancel_func_ = cancel_func;
|
||||
cancel_func_ = std::move(cancel_func);
|
||||
}
|
||||
|
||||
void CallOptions::ClearCancelCallback() {
|
||||
|
@ -127,10 +127,15 @@ float V(const Tensor& tensor) {
|
||||
|
||||
static uint64 kIncarnation = 1; // Uses in following tests.
|
||||
|
||||
string Key(const string& sender, const uint64 incarnation,
|
||||
const string& receiver, const string& name) {
|
||||
return Rendezvous::CreateKey(sender, incarnation, receiver, name,
|
||||
FrameAndIter(0, 0));
|
||||
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
|
||||
const string& receiver, const string& name) {
|
||||
Rendezvous::ParsedKey result;
|
||||
CHECK(
|
||||
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
|
||||
name, FrameAndIter(0, 0)),
|
||||
&result)
|
||||
.ok());
|
||||
return result;
|
||||
}
|
||||
|
||||
#define ALICE "/job:j/replica:0/task:0/cpu:0"
|
||||
|
@ -306,10 +306,15 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||
|
||||
// Sends values specified by the caller.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
for (const auto& p : in) {
|
||||
const string& key = p.first;
|
||||
const Tensor& val = p.second;
|
||||
const Status s = rendezvous->Send(key, Rendezvous::Args(), val, false);
|
||||
|
||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
||||
if (s.ok()) {
|
||||
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
item->Unref();
|
||||
@ -337,7 +342,10 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
LogMemory::RecordStep(args.step_id, handle);
|
||||
}
|
||||
thread::ThreadPool* pool = worker_env_->compute_pool;
|
||||
args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
|
||||
using namespace std::placeholders;
|
||||
// Line below is equivalent to this code, but does one less indirect call:
|
||||
// args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
|
||||
args.runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
|
||||
for (const auto& unit : item->units) {
|
||||
unit.root->RunAsync(args, barrier->Get());
|
||||
}
|
||||
@ -347,11 +355,15 @@ void GraphMgr::RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
|
||||
StatusCallback done, Status s) {
|
||||
if (s.ok()) {
|
||||
// Receives values requested by the caller.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
for (auto& p : *out) {
|
||||
const string& key = p.first;
|
||||
Tensor* val = &p.second;
|
||||
bool is_dead = false;
|
||||
s = rendezvous->Recv(key, Rendezvous::Args(), val, &is_dead);
|
||||
s = Rendezvous::ParseKey(key, &parsed);
|
||||
if (s.ok()) {
|
||||
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
|
||||
}
|
||||
if (is_dead) {
|
||||
s = errors::InvalidArgument("The tensor returned for ", key,
|
||||
" was not valid.");
|
||||
|
@ -57,12 +57,13 @@ class RendezvousMgrInterface {
|
||||
// "done" when the tensor for "key" is produced or an error occurs.
|
||||
//
|
||||
// This method is used by the rpc handler of RecvTensor.
|
||||
virtual void RecvLocalAsync(int64 step_id, const string& key,
|
||||
virtual void RecvLocalAsync(int64 step_id,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
Rendezvous::DoneCallback done) = 0;
|
||||
|
||||
// Synchronous wrapper for RecvLocalAsync.
|
||||
virtual Status RecvLocal(int64 step_id, const string& key, Tensor* val,
|
||||
bool* is_dead) = 0;
|
||||
virtual Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
|
||||
Tensor* val, bool* is_dead) = 0;
|
||||
|
||||
// Removes rendezvous for "step_id".
|
||||
//
|
||||
|
@ -97,49 +97,66 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
req_copy->set_dma_ok(false);
|
||||
}
|
||||
// Type-specialized logging for this method.
|
||||
StatusCallback logging_callback = [this, request, req_copy, response, done,
|
||||
start_usec](Status s) {
|
||||
if (logger_->LoggingActive()) {
|
||||
int64 end_usec = Env::Default()->NowMicros();
|
||||
int64 step_id = request->step_id();
|
||||
int64 bytes = response->tensor().ByteSize();
|
||||
int64 send_start_usec = start_usec;
|
||||
// If a send start time was reported by the other side, use
|
||||
// that instead. Maybe we should mark the display if we're using
|
||||
// our local time instead of the remote start time?
|
||||
if (response->send_start_micros()) {
|
||||
// send_start_micros is the timestamp taken when the remote
|
||||
// machine began to send the RecvTensor response.
|
||||
// Due to clock skew between source and dest machines, it is
|
||||
// possible that send_start_micros can be larger than end_usec or
|
||||
// less than start_usec.
|
||||
// To respect causality, we enforce the invariants that the RecvTensor
|
||||
// response can not have been sent before the RecvTensor request, and
|
||||
// must have been sent before it was received.
|
||||
send_start_usec = std::max(start_usec, response->send_start_micros());
|
||||
send_start_usec = std::min(send_start_usec, end_usec - 1);
|
||||
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
|
||||
StatusCallback wrapper_done;
|
||||
const StatusCallback* cb_to_use;
|
||||
if (!logging_active && req_copy == nullptr) {
|
||||
cb_to_use = &done; // No additional work to do, so just use done directly
|
||||
} else if (!logging_active) {
|
||||
wrapper_done = [req_copy, done](Status s) {
|
||||
delete req_copy;
|
||||
done(s);
|
||||
};
|
||||
cb_to_use = &wrapper_done;
|
||||
} else {
|
||||
wrapper_done = [this, request, req_copy, response, done,
|
||||
start_usec](Status s) {
|
||||
if (logger_->LoggingActive()) {
|
||||
int64 end_usec = Env::Default()->NowMicros();
|
||||
int64 step_id = request->step_id();
|
||||
int64 bytes = response->tensor().ByteSize();
|
||||
int64 send_start_usec = start_usec;
|
||||
// If a send start time was reported by the other side, use
|
||||
// that instead. Maybe we should mark the display if we're using
|
||||
// our local time instead of the remote start time?
|
||||
if (response->send_start_micros()) {
|
||||
// send_start_micros is the timestamp taken when the
|
||||
// remote machine began to send the RecvTensor response.
|
||||
// Due to clock skew between source and dest machines, it
|
||||
// is possible that send_start_micros can be larger than
|
||||
// end_usec or less than start_usec.
|
||||
//
|
||||
// To respect causality, we enforce the invariants that
|
||||
// the RecvTensor response can not have been sent before
|
||||
// the RecvTensor request, and must have been sent before
|
||||
// it was received.
|
||||
send_start_usec =
|
||||
std::max(start_usec, response->send_start_micros());
|
||||
send_start_usec = std::min(send_start_usec, end_usec - 1);
|
||||
}
|
||||
const string& key = request->rendezvous_key();
|
||||
std::vector<string> key_parts = str_util::Split(key, ';');
|
||||
if (key_parts.size() != 5) {
|
||||
LOG(WARNING) << "Bad key: " << key;
|
||||
} else {
|
||||
logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
|
||||
key_parts[3], // tensor name
|
||||
key_parts[0], // src_device
|
||||
key_parts[2], // dst_device
|
||||
bytes);
|
||||
}
|
||||
}
|
||||
const string& key = request->rendezvous_key();
|
||||
std::vector<string> key_parts = str_util::Split(key, ';');
|
||||
if (key_parts.size() != 5) {
|
||||
LOG(WARNING) << "Bad key: " << key;
|
||||
} else {
|
||||
logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
|
||||
key_parts[3], // tensor name
|
||||
key_parts[0], // src_device
|
||||
key_parts[2], // dst_device
|
||||
bytes);
|
||||
}
|
||||
}
|
||||
VLOG(2) << "done callback, req: " << request->DebugString()
|
||||
<< " response " << response->DebugString();
|
||||
delete req_copy;
|
||||
done(s);
|
||||
};
|
||||
VLOG(2) << "done callback, req: " << request->DebugString()
|
||||
<< " response " << response->DebugString();
|
||||
delete req_copy;
|
||||
done(s);
|
||||
};
|
||||
cb_to_use = &wrapper_done;
|
||||
}
|
||||
|
||||
IssueRequest(req_copy ? req_copy : request, response,
|
||||
&grpc::WorkerService::Stub::AsyncRecvTensor, logging_callback,
|
||||
call_opts);
|
||||
&grpc::WorkerService::Stub::AsyncRecvTensor,
|
||||
std::move(*cb_to_use), call_opts);
|
||||
}
|
||||
|
||||
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
|
||||
|
@ -349,11 +349,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
|
||||
// Helper for RecvTensor. Validates "key" and returns the source
|
||||
// device in "*src_dev".
|
||||
Status PrepareRecvTensor(const string& key, Device** src_dev) {
|
||||
// Validate the key.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
|
||||
Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
|
||||
Device** src_dev) {
|
||||
// Figures out which device the tensor is hosted on.
|
||||
TF_RETURN_IF_ERROR(
|
||||
env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
|
||||
@ -375,8 +372,12 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
const int64 step_id = call->request.step_id();
|
||||
const string& key = call->request.rendezvous_key();
|
||||
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
||||
Device* src_dev = nullptr;
|
||||
Status s = PrepareRecvTensor(key, &src_dev);
|
||||
if (s.ok()) {
|
||||
s = PrepareRecvTensor(parsed, &src_dev);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
return;
|
||||
@ -388,7 +389,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
// cancellation should abort the rendezvous.
|
||||
call->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
|
||||
env_->rendezvous_mgr->RecvLocalAsync(
|
||||
step_id, key,
|
||||
step_id, parsed,
|
||||
[this, call, src_dev](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
|
@ -41,8 +41,7 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
: BaseRemoteRendezvous(env, step_id, false) {}
|
||||
|
||||
protected:
|
||||
void RecvFromRemoteAsync(const string& key,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
@ -55,23 +54,49 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
// Used only to retrieve tensors from remote processes.
|
||||
class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
public:
|
||||
RpcRecvTensorCall(WorkerCacheInterface* wc, WorkerInterface* wi,
|
||||
int64 step_id, const string& key,
|
||||
const string& remote_dev, Allocator* allocator,
|
||||
Device* dst_device)
|
||||
: wi_(wi),
|
||||
wc_(wc),
|
||||
remote_dev_(remote_dev),
|
||||
allocator_(allocator),
|
||||
dst_(dst_device) {
|
||||
RpcRecvTensorCall()
|
||||
: wi_(nullptr), wc_(nullptr), allocator_(nullptr), dst_device_(nullptr) {}
|
||||
|
||||
void Init(WorkerCacheInterface* wc, WorkerInterface* wi, int64 step_id,
|
||||
StringPiece key, Allocator* allocator, Device* dst_device,
|
||||
const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
|
||||
wi_ = wi;
|
||||
wc_ = wc;
|
||||
allocator_ = allocator;
|
||||
dst_device_ = dst_device;
|
||||
recv_args_ = recv_args;
|
||||
done_ = std::move(done);
|
||||
req_.set_step_id(step_id);
|
||||
req_.set_rendezvous_key(key);
|
||||
req_.set_rendezvous_key(key.data(), key.size());
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
delete wi_;
|
||||
wi_ = nullptr;
|
||||
wc_ = nullptr;
|
||||
allocator_ = nullptr;
|
||||
dst_device_ = nullptr;
|
||||
// We don't clear opts_ and assume that Init will set up the state for
|
||||
// opts_ appropriately.
|
||||
req_.Clear();
|
||||
if (resp_.ByteSize() > 128) {
|
||||
// Clear memory from resp_ if it is too large
|
||||
RecvTensorResponse empty;
|
||||
resp_.Swap(&empty);
|
||||
} else {
|
||||
resp_.Clear();
|
||||
}
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
status_ = Status::OK();
|
||||
}
|
||||
done_ = nullptr;
|
||||
}
|
||||
|
||||
~RpcRecvTensorCall() override { delete wi_; }
|
||||
|
||||
void Start(std::function<void()> recv_done) override {
|
||||
StartRTCall(recv_done);
|
||||
StartRTCall(std::move(recv_done));
|
||||
}
|
||||
|
||||
void StartAbort(const Status& s) override {
|
||||
@ -93,6 +118,10 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
|
||||
bool is_dead() const { return resp_.is_dead(); }
|
||||
|
||||
Device* dst_device() const { return dst_device_; }
|
||||
const Rendezvous::Args& recv_args() const { return recv_args_; }
|
||||
const Rendezvous::DoneCallback& done() const { return done_; }
|
||||
|
||||
private:
|
||||
// Start the main RecvTensor call, checking for an async abort.
|
||||
void StartRTCall(std::function<void()> recv_done) {
|
||||
@ -100,7 +129,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
nullptr /* TensorBufAllocator */,
|
||||
// done callback
|
||||
[this, recv_done](const Status& s) {
|
||||
{
|
||||
if (!s.ok()) {
|
||||
mutex_lock l(mu_);
|
||||
status_.Update(s);
|
||||
}
|
||||
@ -110,12 +139,13 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
|
||||
WorkerInterface* wi_; // Owned.
|
||||
WorkerCacheInterface* wc_; // Not owned.
|
||||
string remote_dev_;
|
||||
Allocator* allocator_;
|
||||
Device* dst_;
|
||||
Device* dst_device_;
|
||||
CallOptions opts_;
|
||||
RecvTensorRequest req_;
|
||||
RecvTensorResponse resp_;
|
||||
Rendezvous::Args recv_args_;
|
||||
Rendezvous::DoneCallback done_;
|
||||
|
||||
mutable mutex mu_;
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
@ -123,10 +153,53 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
|
||||
};
|
||||
|
||||
namespace {
|
||||
// A bounded pool of RpcRecvTensorCall objects. New() hands out a pooled
// object when one is available and allocates otherwise; Release() resets an
// object and either returns it to the pool or deletes it once the pool
// already holds kMaxObjects entries. Access to the pool vector in New() and
// Release() is serialized by mu_.
class RpcRecvTensorFreeList {
|
||||
public:
|
||||
RpcRecvTensorFreeList() {}
|
||||
// Deletes every object still held in the pool.
// NOTE(review): objects_ is read here without holding mu_, and the loop
// compares a signed int against objects_.size() -- confirm both are
// intentional.
~RpcRecvTensorFreeList() {
|
||||
for (int i = 0; i < objects_.size(); i++) {
|
||||
delete objects_[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a call object: the most recently released one if the pool is
// non-empty, otherwise a newly allocated one. The allocation happens after
// the lock scope has ended.
RpcRecvTensorCall* New() {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!objects_.empty()) {
|
||||
RpcRecvTensorCall* result = objects_.back();
|
||||
objects_.pop_back();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return new RpcRecvTensorCall;
|
||||
}
|
||||
|
||||
// Takes ownership of obj. The object is Reset() before the lock is taken,
// then pushed onto the pool if there is room; otherwise it is deleted
// outside the lock scope.
void Release(RpcRecvTensorCall* obj) {
|
||||
obj->Reset();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
// NOTE(review): size() is unsigned while kMaxObjects is an int.
if (objects_.size() < kMaxObjects) {
|
||||
objects_.push_back(obj);
|
||||
return;
|
||||
}
|
||||
}
|
||||
delete obj;
|
||||
}
|
||||
|
||||
private:
|
||||
// Upper bound on the number of idle objects kept in the pool.
static const int kMaxObjects = 1000;
|
||||
|
||||
mutex mu_;
|
||||
// Idle objects available for reuse; used as a stack (back() / pop_back()).
std::vector<RpcRecvTensorCall*> objects_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
static RpcRecvTensorFreeList call_freelist_;
|
||||
}
|
||||
|
||||
void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
const string& key, const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args, DoneCallback done) {
|
||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
Status s;
|
||||
|
||||
// key.src_device identifies a remote device.
|
||||
@ -137,11 +210,15 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
s = errors::Internal(parsed.src_device,
|
||||
" is invalid remote source device.");
|
||||
}
|
||||
// TODO(jeff): Consider checking for a valid worker_cache during the
|
||||
// constructor of RpcRemoteRendezvous, rather than here, to simplify
|
||||
// the twisty logic below.
|
||||
WorkerCacheInterface* worker_cache = env_->worker_cache;
|
||||
if (s.ok() && worker_cache == nullptr) {
|
||||
s = errors::Internal("No remote worker cache available.");
|
||||
}
|
||||
WorkerInterface* rwi = env_->worker_cache->CreateWorker(src_worker);
|
||||
WorkerInterface* rwi =
|
||||
(worker_cache ? worker_cache->CreateWorker(src_worker) : nullptr);
|
||||
if (s.ok() && rwi == nullptr) {
|
||||
s = errors::Internal("No worker known as ", src_worker);
|
||||
}
|
||||
@ -157,15 +234,16 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs);
|
||||
|
||||
// Prepare a RecvTensor call that can handle being aborted.
|
||||
RpcRecvTensorCall* call =
|
||||
new RpcRecvTensorCall(worker_cache, rwi, step_id_, key,
|
||||
parsed.src_device, allocator, dst_device);
|
||||
RpcRecvTensorCall* call = call_freelist_.New();
|
||||
|
||||
call->Init(worker_cache, rwi, step_id_, parsed.FullKey(), allocator,
|
||||
dst_device, recv_args, std::move(done));
|
||||
|
||||
// Record "call" in active_ so that it can be aborted cleanly.
|
||||
RegisterCall(call);
|
||||
|
||||
// Start "call".
|
||||
call->Start([this, call, parsed, recv_args, done]() {
|
||||
call->Start([this, call]() {
|
||||
// Removes "call" from active_. Prevent StartAbort().
|
||||
DeregisterCall(call);
|
||||
// If StartAbort was called prior to DeregisterCall, then the
|
||||
@ -173,24 +251,19 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
Status s = call->status();
|
||||
Tensor val;
|
||||
if (s.ok()) {
|
||||
Device* dst_device;
|
||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||
if (s.ok()) {
|
||||
s = dst_device->MakeTensorFromProto(call->tensor_proto(),
|
||||
recv_args.alloc_attrs, &val);
|
||||
}
|
||||
s = call->dst_device()->MakeTensorFromProto(
|
||||
call->tensor_proto(), call->recv_args().alloc_attrs, &val);
|
||||
}
|
||||
done(s, Args(), recv_args, val, call->is_dead());
|
||||
delete call;
|
||||
call->done()(s, Args(), call->recv_args(), val, call->is_dead());
|
||||
call_freelist_.Release(call);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) {
|
||||
const WorkerEnv* worker_env) {
|
||||
return new RpcRemoteRendezvous(worker_env, step_id);
|
||||
}
|
||||
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -40,15 +40,21 @@ string V(const Tensor& tensor) {
|
||||
return tensor.scalar<string>()();
|
||||
}
|
||||
|
||||
// Test helper: parses the rendezvous key string `s` into a ParsedKey and
// returns it by value. CHECK-fails if Rendezvous::ParseKey does not return
// an OK status.
Rendezvous::ParsedKey MakeKey(const string& s) {
|
||||
Rendezvous::ParsedKey key;
|
||||
CHECK(Rendezvous::ParseKey(s, &key).ok());
|
||||
return key;
|
||||
}
|
||||
|
||||
TEST(RpcRendezvousMgrTest, LocalSendRecv) {
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const int64 step_id = 123;
|
||||
const string key = Rendezvous::CreateKey(
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
@ -69,9 +75,9 @@ TEST(RpcRendezvousMgrTest, LocalAbort) {
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const string key = Rendezvous::CreateKey(
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{ // Explicit Abort().
|
||||
const int64 step_id = 123;
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
@ -105,9 +111,9 @@ TEST(RpcRendezvousMgrTest, CleanupAll) {
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const string key = Rendezvous::CreateKey(
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{
|
||||
const int64 step_id = 123;
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
@ -139,9 +145,9 @@ TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const int64 step_id = 123;
|
||||
const string key = Rendezvous::CreateKey(
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace Eigen {
|
||||
@ -80,7 +81,7 @@ class DeviceContext : public core::RefCounted {
|
||||
// device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
|
||||
// to be of the same size as "device_tensor".
|
||||
virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
const string& tensor_name, Device* device,
|
||||
StringPiece tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) {
|
||||
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -30,6 +31,21 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Copy-assigns a ParsedKey. The backing string buf_ is copied, and the three
// StringPiece fields are re-pointed at the same offsets inside this object's
// own buf_, so they refer to the copy rather than to b's buffer. The parsed
// device names and the incarnation number are copied by value.
//
// NOTE(review): each offset is computed as (piece.data() - b.buf_.data()),
// which assumes b's pieces point into b.buf_. ParseKey sets them up that way
// on success; whether that also holds for a default-constructed or
// failed-parse key is not visible here -- confirm.
Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) {
|
||||
// Capture b's base pointer before buf_ is overwritten.
const char* b_base = b.buf_.data();
|
||||
buf_ = b.buf_;
|
||||
src_device.set(buf_.data() + (b.src_device.data() - b_base),
|
||||
b.src_device.size());
|
||||
src = b.src;
|
||||
src_incarnation = b.src_incarnation;
|
||||
dst_device.set(buf_.data() + (b.dst_device.data() - b_base),
|
||||
b.dst_device.size());
|
||||
dst = b.dst;
|
||||
edge_name.set(buf_.data() + (b.edge_name.data() - b_base),
|
||||
b.edge_name.size());
|
||||
return *this;
|
||||
}
|
||||
|
||||
/* static */
|
||||
string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
|
||||
const string& dst_device, const string& name,
|
||||
@ -66,7 +82,8 @@ static StringPiece ConsumeNextPart(StringPiece* s, char delim) {
|
||||
|
||||
/* static */
|
||||
Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
|
||||
StringPiece s(key);
|
||||
out->buf_ = key; // Make a copy that our StringPieces can point at
|
||||
StringPiece s(out->buf_);
|
||||
StringPiece parts[5];
|
||||
for (int i = 0; i < 5; i++) {
|
||||
parts[i] = ConsumeNextPart(&s, ';');
|
||||
@ -77,9 +94,9 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
|
||||
strings::HexStringToUint64(parts[1], &out->src_incarnation) &&
|
||||
DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
|
||||
!parts[3].empty()) {
|
||||
out->src_device.assign(parts[0].data(), parts[0].size());
|
||||
out->dst_device.assign(parts[2].data(), parts[2].size());
|
||||
out->edge_name.assign(parts[3].data(), parts[3].size());
|
||||
out->src_device.set(parts[0].data(), parts[0].size());
|
||||
out->dst_device.set(parts[2].data(), parts[2].size());
|
||||
out->edge_name.set(parts[3].data(), parts[3].size());
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument("Invalid rendezvous key: ", key);
|
||||
@ -87,8 +104,8 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
|
||||
|
||||
Rendezvous::~Rendezvous() {}
|
||||
|
||||
Status Rendezvous::Recv(const string& key, const Args& recv_args, Tensor* val,
|
||||
bool* is_dead) {
|
||||
Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
|
||||
Tensor* val, bool* is_dead) {
|
||||
Status ret;
|
||||
Notification n;
|
||||
RecvAsync(key, recv_args,
|
||||
@ -109,18 +126,19 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
explicit LocalRendezvousImpl(bool tolerate_dup_recv)
|
||||
: tolerate_dup_recv_(tolerate_dup_recv) {}
|
||||
|
||||
Status Send(const string& key, const Args& send_args, const Tensor& val,
|
||||
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
|
||||
const bool is_dead) override {
|
||||
VLOG(2) << "Send " << this << " " << key;
|
||||
DoneCallback waiter = nullptr;
|
||||
Args recv_args;
|
||||
uint64 key_hash = KeyHash(key.FullKey());
|
||||
VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) {
|
||||
return status_;
|
||||
}
|
||||
Item* item = nullptr;
|
||||
Table::iterator iter = table_.find(key);
|
||||
Table::iterator iter = table_.find(key_hash);
|
||||
if (iter == table_.end()) {
|
||||
// There is no waiter for this message. Insert the message
|
||||
// into the waiters table. The waiter will pick it up when
|
||||
@ -138,19 +156,24 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
// The allocator attributes of item->value.
|
||||
item->send_alloc_attrs = send_args.alloc_attrs;
|
||||
|
||||
CHECK(table_.insert({key, item}).second);
|
||||
CHECK(table_.insert({key_hash, item}).second);
|
||||
return Status::OK();
|
||||
} else {
|
||||
item = iter->second;
|
||||
|
||||
if (item->waiter == nullptr) {
|
||||
// There is already a message in the table under the key.
|
||||
// Should not happen unless it has a waiter.
|
||||
return errors::Aborted("Duplicated send: ", key);
|
||||
return errors::Aborted("Duplicated send: ", key.FullKey());
|
||||
}
|
||||
// Mark item as complete.
|
||||
item->has_been_recvd = true;
|
||||
waiter = item->waiter;
|
||||
item->waiter = nullptr;
|
||||
|
||||
// Get item->waiter function into waiter and set item->waiter to null
|
||||
std::swap(item->waiter, waiter);
|
||||
DCHECK(item->waiter == nullptr);
|
||||
DCHECK(waiter != nullptr);
|
||||
|
||||
// The ref on recv_dev_context transfers below.
|
||||
recv_args.device_context = item->recv_dev_context;
|
||||
recv_args.alloc_attrs = item->recv_alloc_attrs;
|
||||
@ -173,9 +196,10 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecvAsync(const string& key, const Args& recv_args,
|
||||
void RecvAsync(const ParsedKey& key, const Args& recv_args,
|
||||
DoneCallback done) override {
|
||||
VLOG(2) << "Recv " << this << " " << key;
|
||||
uint64 key_hash = KeyHash(key.FullKey());
|
||||
VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
|
||||
mu_.lock();
|
||||
if (!status_.ok()) {
|
||||
// Rendezvous has been aborted.
|
||||
@ -184,13 +208,13 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
done(s, Args(), recv_args, Tensor(), false);
|
||||
return;
|
||||
}
|
||||
Table::iterator iter = table_.find(key);
|
||||
Table::iterator iter = table_.find(key_hash);
|
||||
if (iter != table_.end()) {
|
||||
Item* item = iter->second;
|
||||
if (item->has_been_recvd && !tolerate_dup_recv_) {
|
||||
mu_.unlock();
|
||||
done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
|
||||
Tensor(), false);
|
||||
done(errors::Aborted("Duplicated recv: ", key.FullKey()), Args(),
|
||||
recv_args, Tensor(), false);
|
||||
} else if (item->waiter == nullptr || tolerate_dup_recv_) {
|
||||
// A message has already arrived and is stored in the table
|
||||
// under this key. Consumes the message and invokes the done
|
||||
@ -218,8 +242,8 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
// Already have a waiter in the waiters table under this key,
|
||||
// which should not happen.
|
||||
mu_.unlock();
|
||||
done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
|
||||
Tensor(), false);
|
||||
done(errors::Aborted("Duplicated recv: ", key.FullKey()), Args(),
|
||||
recv_args, Tensor(), false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -227,13 +251,13 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
// waiting table. The done closure will be invoked when the
|
||||
// message arrives.
|
||||
Item* item = new Item;
|
||||
item->waiter = done;
|
||||
item->waiter = std::move(done);
|
||||
item->recv_alloc_attrs = recv_args.alloc_attrs;
|
||||
if (recv_args.device_context) {
|
||||
item->recv_dev_context = recv_args.device_context;
|
||||
item->recv_dev_context->Ref();
|
||||
}
|
||||
CHECK(table_.insert({key, item}).second);
|
||||
CHECK(table_.insert({key_hash, item}).second);
|
||||
mu_.unlock();
|
||||
return;
|
||||
}
|
||||
@ -280,7 +304,12 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
}
|
||||
}
|
||||
};
|
||||
typedef std::unordered_map<string, Item*> Table;
|
||||
// We key the hash table by KeyHash of the Rendezvous::CreateKey string
|
||||
// Returns Hash64 computed over the k.size() bytes starting at k.data().
// NOTE(review): the table is keyed by this 64-bit value alone, so two
// distinct key strings with the same hash would share an entry -- confirm
// that is acceptable.
static uint64 KeyHash(const StringPiece& k) {
|
||||
return Hash64(k.data(), k.size());
|
||||
}
|
||||
|
||||
typedef std::unordered_map<uint64, Item*> Table;
|
||||
|
||||
// TODO(zhifengc): shard table_.
|
||||
mutex mu_;
|
||||
|
@ -54,12 +54,22 @@ class Rendezvous : public core::RefCounted {
|
||||
// Parses the key constructed by CreateKey and parse src/dst device
|
||||
// names into structures respectively.
|
||||
struct ParsedKey {
|
||||
string src_device;
|
||||
StringPiece src_device;
|
||||
DeviceNameUtils::ParsedName src;
|
||||
uint64 src_incarnation = 0;
|
||||
string dst_device;
|
||||
StringPiece dst_device;
|
||||
DeviceNameUtils::ParsedName dst;
|
||||
string edge_name;
|
||||
StringPiece edge_name;
|
||||
|
||||
ParsedKey() {}
|
||||
ParsedKey(const ParsedKey& b) { *this = b; }
|
||||
|
||||
ParsedKey& operator=(const ParsedKey& b);
|
||||
StringPiece FullKey() const { return buf_; }
|
||||
|
||||
private:
|
||||
friend class Rendezvous;
|
||||
string buf_;
|
||||
};
|
||||
static Status ParseKey(const string& key, ParsedKey* out);
|
||||
|
||||
@ -74,7 +84,7 @@ class Rendezvous : public core::RefCounted {
|
||||
// Send/Recv on the same worker.
|
||||
//
|
||||
// Send() never blocks.
|
||||
virtual Status Send(const string& key, const Args& args, const Tensor& val,
|
||||
virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
|
||||
const bool is_dead) = 0;
|
||||
|
||||
// Callback provided by a tensor consumer waiting on the rendezvous.
|
||||
@ -84,13 +94,15 @@ class Rendezvous : public core::RefCounted {
|
||||
// receiver, which may be needed when a non-CPU device is in use
|
||||
// by either side.
|
||||
typedef std::function<void(const Status&, const Args&, const Args&,
|
||||
const Tensor&, const bool)> DoneCallback;
|
||||
const Tensor&, const bool)>
|
||||
DoneCallback;
|
||||
|
||||
virtual void RecvAsync(const string& key, const Args& args,
|
||||
virtual void RecvAsync(const ParsedKey& key, const Args& args,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
// Synchronous wrapper for RecvAsync.
|
||||
Status Recv(const string& key, const Args& args, Tensor* val, bool* is_dead);
|
||||
Status Recv(const ParsedKey& key, const Args& args, Tensor* val,
|
||||
bool* is_dead);
|
||||
|
||||
// Aborts all pending and future Send/Recv with the given "status".
|
||||
//
|
||||
|
@ -96,13 +96,29 @@ string V(const Tensor& tensor) {
|
||||
return tensor.scalar<string>()();
|
||||
}
|
||||
|
||||
const char* kFoo = "/cpu:0;1;/cpu:1;foo;1;2";
|
||||
const char* kBar = "/gpu:0;2;/gpu:1;bar;1;2";
|
||||
|
||||
// Test helper: builds a rendezvous key string with Rendezvous::CreateKey
// using fixed source/destination device names, incarnation 7890 and
// FrameAndIter(0, 0), with `name` as the only varying argument, then parses
// it into a ParsedKey. A parse failure is reported through TF_EXPECT_OK and
// the key is returned regardless.
Rendezvous::ParsedKey MakeKey(const string& name) {
|
||||
string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/GPU:0", name,
|
||||
FrameAndIter(0, 0));
|
||||
Rendezvous::ParsedKey k;
|
||||
TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
|
||||
return k;
|
||||
}
|
||||
|
||||
Rendezvous::ParsedKey KeyFoo() { return MakeKey("foo"); }
|
||||
Rendezvous::ParsedKey KeyBar() { return MakeKey("bar"); }
|
||||
|
||||
TEST_F(LocalRendezvousTest, SendRecv) {
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, V("hello"), false)));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
|
||||
EXPECT_TRUE(
|
||||
errors::IsAborted(rendez_->Send(KeyFoo(), args, V("hello"), false)));
|
||||
Tensor val(DT_STRING);
|
||||
bool is_dead = false;
|
||||
TF_ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
|
||||
EXPECT_EQ("hello", V(val));
|
||||
}
|
||||
|
||||
@ -110,12 +126,12 @@ TEST_F(LocalRendezvousTest, RecvSend) {
|
||||
SchedClosure([this]() {
|
||||
Env::Default()->SleepForMicroseconds(10000);
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
|
||||
});
|
||||
Tensor val(DT_STRING);
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
|
||||
EXPECT_EQ("hello", V(val));
|
||||
}
|
||||
|
||||
@ -124,16 +140,17 @@ TEST_F(LocalRendezvousTest, DuplicateWaiterRecv) {
|
||||
Tensor t(DT_STRING);
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
|
||||
});
|
||||
Env::Default()->SleepForMicroseconds(1000000);
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
|
||||
EXPECT_TRUE(
|
||||
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
|
||||
EXPECT_EQ("secret msg", V(val));
|
||||
}
|
||||
|
||||
@ -142,17 +159,18 @@ TEST_F(LocalRendezvousTest, DuplicateSerialRecv) {
|
||||
Tensor t(DT_STRING);
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
|
||||
});
|
||||
Env::Default()->SleepForMicroseconds(1000000);
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
|
||||
EXPECT_EQ("secret msg", V(val));
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
|
||||
EXPECT_TRUE(
|
||||
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
|
||||
}
|
||||
|
||||
// A simple structure that behaves a bit like a blocking counter. The
|
||||
@ -174,7 +192,7 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
|
||||
random::SimplePhilox rnd(&philox);
|
||||
Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000));
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send(strings::StrCat(i), args,
|
||||
TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args,
|
||||
V(strings::StrCat(i)), false));
|
||||
});
|
||||
SchedClosure([this, &state, i]() {
|
||||
@ -184,7 +202,8 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Recv(strings::StrCat(i), args, &val, &val_dead));
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Recv(MakeKey(strings::StrCat(i)), args, &val, &val_dead));
|
||||
EXPECT_EQ(strings::StrCat(i), V(val));
|
||||
bool done = false;
|
||||
{
|
||||
@ -212,7 +231,7 @@ TEST_F(LocalRendezvousTest, RecvAbort) {
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
Status status = rendez_->Recv("foo", args, &val, &val_dead);
|
||||
Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
|
||||
EXPECT_TRUE(errors::IsAborted(status));
|
||||
}
|
||||
|
||||
@ -228,7 +247,7 @@ TEST_F(LocalRendezvousTest, RecvSleepAbort) {
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
Status status = rendez_->Recv("foo", args, &val, &val_dead);
|
||||
Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
|
||||
EXPECT_TRUE(errors::IsAborted(status));
|
||||
}
|
||||
|
||||
@ -237,8 +256,9 @@ TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, val, val_dead)));
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Send(KeyFoo(), args, val, val_dead)));
|
||||
EXPECT_TRUE(
|
||||
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
|
||||
}
|
||||
|
||||
class DummyDeviceContext : public DeviceContext {
|
||||
@ -255,15 +275,15 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
|
||||
Rendezvous::Args args;
|
||||
args.device_context = new DummyDeviceContext(123);
|
||||
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
|
||||
|
||||
Notification n;
|
||||
Rendezvous::Args args1;
|
||||
args1.device_context = new DummyDeviceContext(1);
|
||||
rendez_->RecvAsync("foo", args1, [&n](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& val, bool is_dead) {
|
||||
rendez_->RecvAsync(KeyFoo(), args1, [&n](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& val, bool is_dead) {
|
||||
CHECK_EQ(123,
|
||||
dynamic_cast<const DummyDeviceContext*>(send_args.device_context)
|
||||
->stream_id());
|
||||
@ -284,8 +304,8 @@ static void BM_SendRecv(int iters) {
|
||||
Status s;
|
||||
if (iters > 0) {
|
||||
while (iters--) {
|
||||
s = rendez->Send("foo", args, orig, is_dead);
|
||||
s = rendez->Recv("foo", args, &val, &is_dead);
|
||||
s = rendez->Send(KeyFoo(), args, orig, is_dead);
|
||||
s = rendez->Recv(KeyFoo(), args, &val, &is_dead);
|
||||
}
|
||||
CHECK_EQ(V(val), V(orig));
|
||||
}
|
||||
@ -307,8 +327,8 @@ static void BM_RecvSend(int iters) {
|
||||
Rendezvous::Args args;
|
||||
Status s;
|
||||
for (int i = 0; i < iters / 2; ++i) {
|
||||
s = rendez->Recv("foo", args, &foo, &is_dead);
|
||||
s = rendez->Send("bar", args, bar, is_dead);
|
||||
s = rendez->Recv(KeyFoo(), args, &foo, &is_dead);
|
||||
s = rendez->Send(KeyBar(), args, bar, is_dead);
|
||||
}
|
||||
CHECK_EQ("foo", V(foo));
|
||||
});
|
||||
@ -318,8 +338,8 @@ static void BM_RecvSend(int iters) {
|
||||
Rendezvous::Args args;
|
||||
Status s;
|
||||
for (int i = 0; i < iters / 2; ++i) {
|
||||
s = rendez->Send("foo", args, foo, is_dead);
|
||||
s = rendez->Recv("bar", args, &bar, &is_dead);
|
||||
s = rendez->Send(KeyFoo(), args, foo, is_dead);
|
||||
s = rendez->Recv(KeyBar(), args, &bar, &is_dead);
|
||||
}
|
||||
CHECK_EQ("bar", V(bar));
|
||||
delete pool;
|
||||
|
@ -118,6 +118,7 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "fill_functor",
|
||||
srcs = ["fill_functor.cc"],
|
||||
hdrs = ["fill_functor.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
@ -1032,7 +1033,6 @@ tf_kernel_libraries(
|
||||
],
|
||||
deps = [
|
||||
":bounds_check",
|
||||
":constant_op",
|
||||
":fill_functor",
|
||||
":transpose_functor",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -1541,7 +1541,6 @@ tf_kernel_libraries(
|
||||
],
|
||||
deps = [
|
||||
":bounds_check",
|
||||
":constant_op",
|
||||
":cwise_op",
|
||||
":fill_functor",
|
||||
":scatter_op",
|
||||
@ -1723,6 +1722,7 @@ filegroup(
|
||||
"dense_update_ops.cc",
|
||||
"dense_update_ops.h",
|
||||
"example_parsing_ops.cc",
|
||||
"fill_functor.cc",
|
||||
"fill_functor.h",
|
||||
"gather_op.cc",
|
||||
"gather_op.h",
|
||||
|
@ -114,37 +114,6 @@ struct FillFunctor<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
// Partial specialization of SetZeroFunctor<Device=CPUDevice, T>.
|
||||
template <typename T>
|
||||
struct SetZeroFunctor<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T>::Flat out) {
|
||||
out.device(d) = out.constant(T(0));
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization of SetZeroFunctor<Device=CPUDevice, T=string>.
|
||||
template <>
|
||||
struct SetZeroFunctor<CPUDevice, string> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<string>::Flat out) {
|
||||
out.device(d) = out.constant(string());
|
||||
}
|
||||
};
|
||||
|
||||
#define DEFINE_SETZERO_CPU(T) template struct SetZeroFunctor<CPUDevice, T>;
|
||||
DEFINE_SETZERO_CPU(Eigen::half);
|
||||
DEFINE_SETZERO_CPU(float);
|
||||
DEFINE_SETZERO_CPU(double);
|
||||
DEFINE_SETZERO_CPU(uint8);
|
||||
DEFINE_SETZERO_CPU(int8);
|
||||
DEFINE_SETZERO_CPU(uint16);
|
||||
DEFINE_SETZERO_CPU(int16);
|
||||
DEFINE_SETZERO_CPU(int32);
|
||||
DEFINE_SETZERO_CPU(int64);
|
||||
DEFINE_SETZERO_CPU(complex64);
|
||||
DEFINE_SETZERO_CPU(complex128);
|
||||
DEFINE_SETZERO_CPU(string);
|
||||
#undef DEFINE_SETZERO_CPU
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
template <typename Device, typename T>
|
||||
|
@ -35,7 +35,6 @@ struct scalar_fmod2_op {
|
||||
return std::fmod(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct functor_traits<scalar_fmod2_op<T>> {
|
||||
enum {
|
||||
@ -44,6 +43,24 @@ struct functor_traits<scalar_fmod2_op<T>> {
|
||||
};
|
||||
};
|
||||
|
||||
// TODO(rmlarsen): This is a workaround for upstream change
|
||||
// https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9
|
||||
// that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary
|
||||
// version of the latter. Remove once we upgrade to Eigen 3.3.
|
||||
// Binary functor whose call operator returns numext::pow(a, b) for a base of
// type Scalar and an exponent of type Exponent. The result type is Scalar.
template <typename Scalar, typename Exponent>
|
||||
struct scalar_binary_pow_op_google {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google)
|
||||
EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
|
||||
const Exponent& b) const {
|
||||
return numext::pow(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename Exponent>
|
||||
struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
|
||||
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
|
||||
};
|
||||
|
||||
template <typename T, typename DivOrMod>
|
||||
struct safe_div_or_mod_op {
|
||||
static_assert(std::is_integral<T>::value, "Integer type expected");
|
||||
@ -477,7 +494,7 @@ struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct pow : base<T, Eigen::internal::scalar_binary_pow_op<T, T> > {};
|
||||
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};
|
||||
|
||||
template <typename T>
|
||||
struct maximum : base<T, Eigen::internal::scalar_max_op<T> > {};
|
||||
|
@ -61,22 +61,23 @@ class DrawBoundingBoxesOp : public OpKernel {
|
||||
|
||||
for (int64 b = 0; b < batch_size; ++b) {
|
||||
const int64 num_boxes = boxes.dim_size(1);
|
||||
const auto tboxes = boxes.tensor<T, 3>();
|
||||
|
||||
for (int64 bb = 0; bb < num_boxes; ++bb) {
|
||||
const int64 min_box_row =
|
||||
(height - 1) * boxes.tensor<float, 3>()(b, bb, 0);
|
||||
static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
|
||||
const int64 min_box_row_clamp =
|
||||
std::max<int64>(min_box_row, 0);
|
||||
const int64 max_box_row =
|
||||
(height - 1) * boxes.tensor<float, 3>()(b, bb, 2);
|
||||
static_cast<float>(tboxes(b, bb, 2)) * (height - 1);
|
||||
const int64 max_box_row_clamp =
|
||||
std::min<int64>(max_box_row, height - 1);
|
||||
const int64 min_box_col =
|
||||
(width - 1) * boxes.tensor<float, 3>()(b, bb, 1);
|
||||
static_cast<float>(tboxes(b, bb, 1)) * (width - 1);
|
||||
const int64 min_box_col_clamp =
|
||||
std::max<int64>(min_box_col, 0);
|
||||
const int64 max_box_col =
|
||||
(width - 1) * boxes.tensor<float, 3>()(b, bb, 3);
|
||||
static_cast<float>(tboxes(b, bb, 3)) * (width - 1);
|
||||
const int64 max_box_col_clamp =
|
||||
std::min<int64>(max_box_col, width - 1);
|
||||
|
||||
@ -121,22 +122,22 @@ class DrawBoundingBoxesOp : public OpKernel {
|
||||
// Draw top line.
|
||||
if (min_box_row >= 0) {
|
||||
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
|
||||
canvas(b, min_box_row, j, 0) = T(nanf(""));
|
||||
canvas(b, min_box_row, j, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
||||
}
|
||||
// Draw bottom line.
|
||||
if (max_box_row < height) {
|
||||
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
|
||||
canvas(b, max_box_row, j, 0) = T(nanf(""));
|
||||
canvas(b, max_box_row, j, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
||||
}
|
||||
// Draw left line.
|
||||
if (min_box_col >= 0) {
|
||||
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
|
||||
canvas(b, i, min_box_col, 0) = T(nanf(""));
|
||||
canvas(b, i, min_box_col, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
||||
}
|
||||
// Draw right line.
|
||||
if (max_box_col < width) {
|
||||
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
|
||||
canvas(b, i, max_box_col, 0) = T(nanf(""));
|
||||
canvas(b, i, max_box_col, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
57
tensorflow/core/kernels/fill_functor.cc
Normal file
57
tensorflow/core/kernels/fill_functor.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
template <typename T>
|
||||
void SetZeroFunctor<Eigen::ThreadPoolDevice, T>::operator()(
|
||||
const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
|
||||
out.device(d) = out.constant(T(0));
|
||||
}
|
||||
|
||||
void SetZeroFunctor<Eigen::ThreadPoolDevice, string>::operator()(
|
||||
const Eigen::ThreadPoolDevice& d, typename TTypes<string>::Flat out) {
|
||||
out.device(d) = out.constant(string());
|
||||
}
|
||||
|
||||
// Explicit instantiations.
|
||||
#define DEFINE_SETZERO_CPU(T) \
|
||||
template struct SetZeroFunctor<Eigen::ThreadPoolDevice, T>;
|
||||
DEFINE_SETZERO_CPU(bool);
|
||||
DEFINE_SETZERO_CPU(Eigen::half);
|
||||
DEFINE_SETZERO_CPU(float);
|
||||
DEFINE_SETZERO_CPU(double);
|
||||
DEFINE_SETZERO_CPU(uint8);
|
||||
DEFINE_SETZERO_CPU(int8);
|
||||
DEFINE_SETZERO_CPU(uint16);
|
||||
DEFINE_SETZERO_CPU(int16);
|
||||
DEFINE_SETZERO_CPU(int32);
|
||||
DEFINE_SETZERO_CPU(int64);
|
||||
DEFINE_SETZERO_CPU(complex64);
|
||||
DEFINE_SETZERO_CPU(complex128);
|
||||
DEFINE_SETZERO_CPU(string);
|
||||
#undef DEFINE_SETZERO_CPU
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
@ -16,8 +16,11 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
|
||||
#define TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
@ -35,6 +38,19 @@ struct SetZeroFunctor {
|
||||
void operator()(const Device& d, typename TTypes<T>::Flat out);
|
||||
};
|
||||
|
||||
// Partial specialization of SetZeroFunctor<Device=Eigen::ThreadPoolDevice, T>.
|
||||
template <typename T>
|
||||
struct SetZeroFunctor<Eigen::ThreadPoolDevice, T> {
|
||||
void operator()(const Eigen::ThreadPoolDevice& d,
|
||||
typename TTypes<T>::Flat out);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SetZeroFunctor<Eigen::ThreadPoolDevice, string> {
|
||||
void operator()(const Eigen::ThreadPoolDevice& d,
|
||||
typename TTypes<string>::Flat out);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -64,19 +64,21 @@ TEST_F(RestoreOpTest, RestoreSimple) {
|
||||
const std::vector<string> tensor_names = {
|
||||
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
|
||||
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
|
||||
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64"};
|
||||
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
|
||||
"tensor_half"};
|
||||
|
||||
// We first need to write a tensor using the save_op
|
||||
{
|
||||
// Initialize an operation
|
||||
NodeDef save;
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "Save")
|
||||
.Input(FakeInput())
|
||||
.Input(FakeInput())
|
||||
.Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE,
|
||||
DT_QINT8, DT_QINT32, DT_UINT8, DT_INT8,
|
||||
DT_INT16, DT_STRING, DT_COMPLEX64}))
|
||||
.Finalize(&save));
|
||||
TF_ASSERT_OK(
|
||||
NodeDefBuilder("myop", "Save")
|
||||
.Input(FakeInput())
|
||||
.Input(FakeInput())
|
||||
.Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8,
|
||||
DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_STRING,
|
||||
DT_COMPLEX64, DT_HALF}))
|
||||
.Finalize(&save));
|
||||
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
@ -156,7 +158,12 @@ TEST_F(RestoreOpTest, RestoreSimple) {
|
||||
TensorShape({2, 3}),
|
||||
[](int x) -> complex64 { return complex64(100 + x, 200 + x); });
|
||||
inputs.push_back({nullptr, &input_13});
|
||||
|
||||
// Input #14 is a 2-d half tensor
|
||||
Tensor input_14 =
|
||||
MakeInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
|
||||
return static_cast<Eigen::half>(x) / Eigen::half(5);
|
||||
});
|
||||
inputs.push_back({nullptr, &input_14});
|
||||
OpKernelContext::Params params;
|
||||
params.device = device.get();
|
||||
params.frame_iter = FrameAndIter(0, 0);
|
||||
@ -321,6 +328,19 @@ TEST_F(RestoreOpTest, RestoreSimple) {
|
||||
EXPECT_EQ(complex64(100 + i, 200 + i), output->flat<complex64>()(i));
|
||||
}
|
||||
}
|
||||
// The 2-d half tensor
|
||||
{
|
||||
MakeRestoreOp(DT_HALF);
|
||||
(*mutable_input(1).tensor).scalar<string>()() = tensor_names[12];
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
Tensor* output = GetOutput(0);
|
||||
TensorShape expected({2, 4});
|
||||
EXPECT_TRUE(output->shape().IsSameSize(expected));
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(5),
|
||||
output->flat<Eigen::half>()(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class RestoreSliceOpTest : public OpsTestBase {
|
||||
|
@ -44,7 +44,7 @@ class SaveOpTest : public OpsTestBase {
|
||||
.Input(FakeInput())
|
||||
.Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8,
|
||||
DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_INT64,
|
||||
DT_STRING, DT_COMPLEX64, DT_COMPLEX128}))
|
||||
DT_STRING, DT_COMPLEX64, DT_COMPLEX128, DT_HALF}))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
@ -53,10 +53,10 @@ class SaveOpTest : public OpsTestBase {
|
||||
TEST_F(SaveOpTest, Simple) {
|
||||
const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
|
||||
const string tensornames[] = {
|
||||
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
|
||||
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
|
||||
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
|
||||
"tensor_complex128"};
|
||||
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
|
||||
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
|
||||
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
|
||||
"tensor_complex128", "tensor_half"};
|
||||
|
||||
MakeOp();
|
||||
// Add a file name
|
||||
@ -64,7 +64,7 @@ TEST_F(SaveOpTest, Simple) {
|
||||
[&filename](int x) -> string { return filename; });
|
||||
|
||||
// Add the tensor names
|
||||
AddInput<string>(TensorShape({13}),
|
||||
AddInput<string>(TensorShape({14}),
|
||||
[&tensornames](int x) -> string { return tensornames[x]; });
|
||||
|
||||
// Add a 1-d bool tensor
|
||||
@ -116,6 +116,10 @@ TEST_F(SaveOpTest, Simple) {
|
||||
return complex128(100 + x, 200 + x);
|
||||
});
|
||||
|
||||
// Add a 2-d half tensor
|
||||
AddInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
|
||||
return static_cast<Eigen::half>(x) / Eigen::half(2);
|
||||
});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check that the checkpoint file is properly written
|
||||
@ -363,6 +367,24 @@ TEST_F(SaveOpTest, Simple) {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// The 2-d half tensor
|
||||
TensorShape shape;
|
||||
DataType type;
|
||||
EXPECT_TRUE(reader.HasTensor("tensor_half", &shape, &type));
|
||||
TensorShape expected({2, 4});
|
||||
EXPECT_TRUE(shape.IsSameSize(expected));
|
||||
EXPECT_EQ(DT_HALF, type);
|
||||
|
||||
// We expect the tensor value to be correct.
|
||||
TensorSlice s = TensorSlice::ParseOrDie("-:-");
|
||||
Eigen::half data[8];
|
||||
std::fill_n(data, 8, Eigen::half(0));
|
||||
EXPECT_TRUE(reader.CopySliceData("tensor_half", s, data));
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(2), data[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class SaveSlicesOpTest : public OpsTestBase {
|
||||
|
@ -32,10 +32,11 @@ static string GetRendezvousKeyPrefix(const string& send_device,
|
||||
recv_device, ";", tensor_name);
|
||||
}
|
||||
|
||||
static string GetRendezvousKey(const string& key_prefix,
|
||||
const FrameAndIter& frame_iter) {
|
||||
return strings::StrCat(key_prefix, ";", frame_iter.frame_id, ":",
|
||||
frame_iter.iter_id);
|
||||
static void GetRendezvousKey(const string& key_prefix,
|
||||
const FrameAndIter& frame_iter, string* key) {
|
||||
key->clear();
|
||||
strings::StrAppend(key, key_prefix, ";", frame_iter.frame_id, ":",
|
||||
frame_iter.iter_id);
|
||||
}
|
||||
|
||||
SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
@ -57,9 +58,13 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES(
|
||||
ctx, ctx->rendezvous() != nullptr,
|
||||
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
||||
const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
|
||||
string key;
|
||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key);
|
||||
VLOG(2) << "Send " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(key, &parsed));
|
||||
|
||||
// The device context may be passed between the Send/Recv
|
||||
// boundary, so that the device context used to produce the Tensor
|
||||
// is used when performing the copy on the recv side (which may be
|
||||
@ -67,9 +72,8 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
||||
Rendezvous::Args args;
|
||||
args.device_context = ctx->op_device_context();
|
||||
args.alloc_attrs = ctx->input_alloc_attr(0);
|
||||
Status s =
|
||||
ctx->rendezvous()->Send(key, args, ctx->input(0), ctx->is_input_dead());
|
||||
ctx->SetStatus(s);
|
||||
OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed, args, ctx->input(0),
|
||||
ctx->is_input_dead()));
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
||||
@ -98,16 +102,21 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
OP_REQUIRES(
|
||||
ctx, ctx->rendezvous() != nullptr,
|
||||
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
||||
const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
|
||||
string key;
|
||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key);
|
||||
VLOG(2) << "Recv " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(key, &parsed), done);
|
||||
|
||||
Rendezvous::Args args;
|
||||
args.device_context = ctx->op_device_context();
|
||||
args.alloc_attrs = ctx->output_alloc_attr(0);
|
||||
ctx->rendezvous()->RecvAsync(
|
||||
key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& val, bool is_dead) {
|
||||
parsed, args,
|
||||
[ctx, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& val,
|
||||
bool is_dead) {
|
||||
ctx->SetStatus(s);
|
||||
if (s.ok()) {
|
||||
// 'ctx' allocates the output tensor of the expected type. The
|
||||
|
@ -74,15 +74,14 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
|
||||
Impl(Env* env, const ThreadOptions& thread_options, const string& name,
|
||||
int num_threads)
|
||||
: Eigen::ThreadPoolTempl<EigenEnvironment>(
|
||||
num_threads, EigenEnvironment(env, thread_options, name)),
|
||||
num_threads_(num_threads) {}
|
||||
num_threads, EigenEnvironment(env, thread_options, name)) {}
|
||||
|
||||
void ParallelFor(int64 total, int64 cost_per_unit,
|
||||
std::function<void(int64, int64)> fn) {
|
||||
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
|
||||
CHECK_GE(total, 0);
|
||||
CHECK_EQ(total, (int64)(Eigen::Index)total);
|
||||
Eigen::ThreadPoolDevice device(this, num_threads_);
|
||||
Eigen::ThreadPoolDevice device(this, this->NumThreads());
|
||||
device.parallelFor(
|
||||
total, Eigen::TensorOpCost(0, 0, cost_per_unit),
|
||||
[&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
|
||||
@ -90,10 +89,6 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
|
||||
CHECK(0); // should not be used with the old thread pool
|
||||
#endif
|
||||
}
|
||||
|
||||
int NumThreads() const { return num_threads_; };
|
||||
|
||||
const int num_threads_;
|
||||
};
|
||||
|
||||
ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
|
||||
@ -120,5 +115,7 @@ void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
|
||||
|
||||
int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
|
||||
|
||||
int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); }
|
||||
|
||||
} // namespace thread
|
||||
} // namespace tensorflow
|
||||
|
@ -57,6 +57,10 @@ class ThreadPool {
|
||||
// Returns the number of threads in the pool.
|
||||
int NumThreads() const;
|
||||
|
||||
// Returns current thread id between 0 and NumThreads() - 1, if called from a
|
||||
// thread in the pool. Returns -1 otherwise.
|
||||
int CurrentThreadId() const;
|
||||
|
||||
struct Impl;
|
||||
|
||||
private:
|
||||
|
@ -1013,7 +1013,7 @@ Computes softmax activations.
|
||||
|
||||
For each batch `i` and class `j` we have
|
||||
|
||||
softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))
|
||||
softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
|
||||
|
||||
logits: 2-D with shape `[batch_size, num_classes]`.
|
||||
softmax: Same shape as `logits`.
|
||||
|
@ -11755,7 +11755,7 @@ op {
|
||||
}
|
||||
}
|
||||
summary: "Computes softmax activations."
|
||||
description: "For each batch `i` and class `j` we have\n\n softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))"
|
||||
description: "For each batch `i` and class `j` we have\n\n softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))"
|
||||
}
|
||||
op {
|
||||
name: "SoftmaxCrossEntropyWithLogits"
|
||||
|
@ -101,7 +101,7 @@ def tf_additional_stream_executor_srcs():
|
||||
return ["platform/default/stream_executor.h"]
|
||||
|
||||
def tf_additional_cupti_wrapper_deps():
|
||||
return [":cupti_wrapper_default"]
|
||||
return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]
|
||||
|
||||
def tf_additional_test_deps():
|
||||
return []
|
||||
|
23
tensorflow/core/platform/default/gpu/BUILD
Normal file
23
tensorflow/core/platform/default/gpu/BUILD
Normal file
@ -0,0 +1,23 @@
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_copts",
|
||||
"tf_cuda_library",
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "cupti_wrapper",
|
||||
srcs = [
|
||||
"cupti_wrapper.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"cupti_wrapper.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
cuda_deps = [
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//third_party/gpus/cuda:cuda_headers",
|
||||
"//third_party/gpus/cuda:cupti_headers",
|
||||
],
|
||||
data = ["//third_party/gpus/cuda:cupti_dsos"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
@ -104,22 +104,29 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
|
||||
}
|
||||
StringPiece tmp;
|
||||
while (!fullname.empty()) {
|
||||
bool progress = false;
|
||||
if (str_util::ConsumePrefix(&fullname, "/job:")) {
|
||||
p->has_job = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
|
||||
return false;
|
||||
}
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/replica:")) {
|
||||
progress = true;
|
||||
}
|
||||
if (str_util::ConsumePrefix(&fullname, "/replica:")) {
|
||||
p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
|
||||
return false;
|
||||
}
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/task:")) {
|
||||
progress = true;
|
||||
}
|
||||
if (str_util::ConsumePrefix(&fullname, "/task:")) {
|
||||
p->has_task = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
|
||||
return false;
|
||||
}
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/device:")) {
|
||||
progress = true;
|
||||
}
|
||||
if (str_util::ConsumePrefix(&fullname, "/device:")) {
|
||||
p->has_type = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
|
||||
return false;
|
||||
@ -132,24 +139,31 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
progress = true;
|
||||
}
|
||||
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/CPU:")) {
|
||||
if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/CPU:")) {
|
||||
p->has_type = true;
|
||||
p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...'
|
||||
p->has_id = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
|
||||
return false;
|
||||
}
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/GPU:")) {
|
||||
progress = true;
|
||||
}
|
||||
if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/GPU:")) {
|
||||
p->has_type = true;
|
||||
p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...'
|
||||
p->has_id = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
progress = true;
|
||||
}
|
||||
|
||||
if (!progress) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -340,11 +354,22 @@ bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
|
||||
string* device) {
|
||||
ParsedName pn;
|
||||
if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
|
||||
*task = strings::StrCat(
|
||||
(pn.has_job ? strings::StrCat("/job:", pn.job) : ""),
|
||||
(pn.has_replica ? strings::StrCat("/replica:", pn.replica) : ""),
|
||||
(pn.has_task ? strings::StrCat("/task:", pn.task) : ""));
|
||||
*device = strings::StrCat(pn.type, ":", pn.id);
|
||||
task->clear();
|
||||
task->reserve(
|
||||
(pn.has_job ? (5 + pn.job.size()) : 0) +
|
||||
(pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
|
||||
(pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
|
||||
if (pn.has_job) {
|
||||
strings::StrAppend(task, "/job:", pn.job);
|
||||
}
|
||||
if (pn.has_replica) {
|
||||
strings::StrAppend(task, "/replica:", pn.replica);
|
||||
}
|
||||
if (pn.has_task) {
|
||||
strings::StrAppend(task, "/task:", pn.task);
|
||||
}
|
||||
device->clear();
|
||||
strings::StrAppend(device, pn.type, ":", pn.id);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -87,6 +87,44 @@ struct CopyThatWorksWithStringPointer<string> {
|
||||
}
|
||||
};
|
||||
|
||||
// Checkpointing of half is done by storing the raw 16 bits as a signed 32bit
|
||||
// integer. To restore the checkpoint we need to do the reverse operation by
|
||||
// reinterpreting the integer as a 16 bit float. This prevents us from using
|
||||
// the default cast operation.
|
||||
template <>
|
||||
struct CopyThatWorksWithStringPointer<Eigen::half> {
|
||||
template <typename SrcTensor, typename DstTensor, typename Shape>
|
||||
static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d,
|
||||
Shape d_start) {
|
||||
typedef typename SrcTensor::Index Index;
|
||||
static_assert(kTensorSliceMaxRank == 8,
|
||||
"If kTensorSliceMaxRank changes, modify the loop below.");
|
||||
for (Index i0 = 0; i0 < len[0]; i0++) {
|
||||
for (Index i1 = 0; i1 < len[1]; i1++) {
|
||||
for (Index i2 = 0; i2 < len[2]; i2++) {
|
||||
for (Index i3 = 0; i3 < len[3]; i3++) {
|
||||
for (Index i4 = 0; i4 < len[4]; i4++) {
|
||||
for (Index i5 = 0; i5 < len[5]; i5++) {
|
||||
for (Index i6 = 0; i6 < len[6]; i6++) {
|
||||
for (Index i7 = 0; i7 < len[7]; i7++) {
|
||||
d(d_start[0] + i0, d_start[1] + i1, d_start[2] + i2,
|
||||
d_start[3] + i3, d_start[4] + i4, d_start[5] + i5,
|
||||
d_start[6] + i6, d_start[7] + i7) =
|
||||
Eigen::internal::raw_uint16_to_half(
|
||||
s(s_start[0] + i0, s_start[1] + i1, s_start[2] + i2,
|
||||
s_start[3] + i3, s_start[4] + i4, s_start[5] + i5,
|
||||
s_start[6] + i6, s_start[7] + i7));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Given a tensor described by "shape", two slices "slice_s" and "slice_d",
|
||||
// and two pointers "ptr_s" and "ptr_d", where "ptr_s" points to a chunk of
|
||||
// memory that stores the data for "slice_s" and "ptr_d" points to a chunk of
|
||||
|
@ -30,9 +30,9 @@ CORPUS_FILENAME = "europarl-v6.fr-en.en"
|
||||
MAX_DOC_LENGTH = 10
|
||||
|
||||
def training_data(filename):
|
||||
f = open(filename)
|
||||
for line in f:
|
||||
yield line
|
||||
f = open(filename)
|
||||
for line in f:
|
||||
yield line
|
||||
|
||||
|
||||
def iter_docs(docs):
|
||||
@ -67,34 +67,34 @@ HIDDEN_SIZE = 10
|
||||
|
||||
|
||||
def seq_autoencoder(X, y):
|
||||
"""Sequence auto-encoder with RNN."""
|
||||
inputs = learn.ops.one_hot_matrix(X, 256)
|
||||
in_X, in_y, out_y = learn.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH)
|
||||
encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
|
||||
decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
|
||||
decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell)
|
||||
return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding)
|
||||
"""Sequence auto-encoder with RNN."""
|
||||
inputs = learn.ops.one_hot_matrix(X, 256)
|
||||
in_X, in_y, out_y = learn.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH)
|
||||
encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
|
||||
decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
|
||||
decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell)
|
||||
return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding)
|
||||
|
||||
|
||||
def get_language_model(hidden_size):
|
||||
"""Returns a language model with given hidden size."""
|
||||
"""Returns a language model with given hidden size."""
|
||||
|
||||
def language_model(X, y):
|
||||
inputs = learn.ops.one_hot_matrix(X, 256)
|
||||
inputs = learn.ops.split_squeeze(1, MAX_DOC_LENGTH, inputs)
|
||||
target = learn.ops.split_squeeze(1, MAX_DOC_LENGTH, y)
|
||||
encoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(hidden_size),256)
|
||||
output, _ = tf.nn.rnn(encoder_cell, inputs, dtype=tf.float32)
|
||||
return learn.ops.sequence_classifier(output, target)
|
||||
|
||||
return language_model
|
||||
def language_model(X, y):
|
||||
inputs = learn.ops.one_hot_matrix(X, 256)
|
||||
inputs = tf.unpack(inputs, axis=1)
|
||||
target = tf.unpack(y, axis=1)
|
||||
encoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(hidden_size),256)
|
||||
output, _ = tf.nn.rnn(encoder_cell, inputs, dtype=tf.float32)
|
||||
return learn.ops.sequence_classifier(output, target)
|
||||
|
||||
return language_model
|
||||
|
||||
|
||||
### Training model.
|
||||
|
||||
estimator = learn.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE),
|
||||
n_classes=256,
|
||||
optimizer='Adam', learning_rate=0.01,
|
||||
steps=1000, batch_size=64, continue_training=True)
|
||||
estimator = learn.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE),
|
||||
n_classes=256, optimizer='Adam',
|
||||
learning_rate=0.01, steps=1000,
|
||||
batch_size=64, continue_training=True)
|
||||
|
||||
estimator.fit(X, y)
|
||||
|
@ -56,7 +56,7 @@ def rnn_model(x, y):
|
||||
|
||||
# Split into list of embedding per word, while removing doc length dim.
|
||||
# word_list is then a list of tensors, each of shape [batch_size, EMBEDDING_SIZE].
|
||||
word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
|
||||
word_list = tf.unpack(word_vectors, axis=1)
|
||||
|
||||
# Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
|
||||
cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
|
||||
|
@ -41,7 +41,7 @@ def input_op_fn(x):
|
||||
embedding_size=EMBEDDING_SIZE, name='words')
|
||||
# Split into list of embedding per word, while removing doc length dim.
|
||||
# word_list is then a list of tensors, each of shape [batch_size, EMBEDDING_SIZE].
|
||||
word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
|
||||
word_list = tf.unpack(word_vectors, axis=1)
|
||||
return word_list
|
||||
|
||||
|
||||
|
@ -47,7 +47,7 @@ def char_rnn_model(x, y):
|
||||
"""Character level recurrent neural network model to predict classes."""
|
||||
y = tf.one_hot(y, 15, 1, 0)
|
||||
byte_list = learn.ops.one_hot_matrix(x, 256)
|
||||
byte_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, byte_list)
|
||||
byte_list = tf.unpack(byte_list, axis=1)
|
||||
|
||||
cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
|
||||
_, encoding = tf.nn.rnn(cell, byte_list, dtype=tf.float32)
|
||||
|
@ -46,47 +46,47 @@ print('Total words: %d' % n_words)
|
||||
EMBEDDING_SIZE = 50
|
||||
|
||||
def average_model(X, y):
|
||||
word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
|
||||
embedding_size=EMBEDDING_SIZE, name='words')
|
||||
features = tf.reduce_max(word_vectors, reduction_indices=1)
|
||||
return learn.models.logistic_regression(features, y)
|
||||
word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
|
||||
embedding_size=EMBEDDING_SIZE, name='words')
|
||||
features = tf.reduce_max(word_vectors, reduction_indices=1)
|
||||
return learn.models.logistic_regression(features, y)
|
||||
|
||||
def rnn_model(X, y):
|
||||
"""Recurrent neural network model to predict from sequence of words
|
||||
"""Recurrent neural network model to predict from sequence of words
|
||||
to a class."""
|
||||
# Convert indexes of words into embeddings.
|
||||
# This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then
|
||||
# maps word indexes of the sequence into [batch_size, sequence_length,
|
||||
# EMBEDDING_SIZE].
|
||||
word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
|
||||
embedding_size=EMBEDDING_SIZE, name='words')
|
||||
# Split into list of embedding per word, while removing doc length dim.
|
||||
# word_list is then a list of tensors, each of shape [batch_size, EMBEDDING_SIZE].
|
||||
word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
|
||||
# Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
|
||||
cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
|
||||
# Create an unrolled Recurrent Neural Networks to length of
|
||||
# MAX_DOCUMENT_LENGTH and passes word_list as inputs for each unit.
|
||||
_, encoding = tf.nn.rnn(cell, word_list, dtype=tf.float32)
|
||||
# Given encoding of RNN, take encoding of last step (e.g. hidden size of the
|
||||
# neural network of last step) and pass it as features for logistic
|
||||
# regression over output classes.
|
||||
return learn.models.logistic_regression(encoding, y)
|
||||
# Convert indexes of words into embeddings.
|
||||
# This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then
|
||||
# maps word indexes of the sequence into [batch_size, sequence_length,
|
||||
# EMBEDDING_SIZE].
|
||||
word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
|
||||
embedding_size=EMBEDDING_SIZE, name='words')
|
||||
# Split into list of embedding per word, while removing doc length dim.
|
||||
# word_list is then a list of tensors, each of shape [batch_size, EMBEDDING_SIZE].
|
||||
word_list = tf.unpack(word_vectors, axis=1)
|
||||
# Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
|
||||
cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
|
||||
# Create an unrolled Recurrent Neural Networks to length of
|
||||
# MAX_DOCUMENT_LENGTH and passes word_list as inputs for each unit.
|
||||
_, encoding = tf.nn.rnn(cell, word_list, dtype=tf.float32)
|
||||
# Given encoding of RNN, take encoding of last step (e.g. hidden size of the
|
||||
# neural network of last step) and pass it as features for logistic
|
||||
# regression over output classes.
|
||||
return learn.models.logistic_regression(encoding, y)
|
||||
|
||||
model_path = '/tmp/skflow_examples/text_classification'
|
||||
if os.path.exists(model_path):
|
||||
classifier = learn.TensorFlowEstimator.restore(model_path)
|
||||
classifier = learn.TensorFlowEstimator.restore(model_path)
|
||||
else:
|
||||
classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
|
||||
steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)
|
||||
classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
|
||||
steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)
|
||||
|
||||
# Train continuously, one fit() call at a time, until interrupted (Ctrl-C).
|
||||
while True:
|
||||
try:
|
||||
classifier.fit(X_train, y_train)
|
||||
except KeyboardInterrupt:
|
||||
classifier.save(model_path)
|
||||
break
|
||||
# Train continuously, one fit() call at a time, until interrupted (Ctrl-C).
|
||||
while True:
|
||||
try:
|
||||
classifier.fit(X_train, y_train)
|
||||
except KeyboardInterrupt:
|
||||
classifier.save(model_path)
|
||||
break
|
||||
# Predict on test set
|
||||
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
|
||||
print('Accuracy: {0:f}'.format(score))
|
||||
|
@ -170,6 +170,8 @@ def train():
|
||||
else: # Record a summary
|
||||
summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
|
||||
train_writer.add_summary(summary, i)
|
||||
train_writer.close()
|
||||
test_writer.close()
|
||||
|
||||
|
||||
def main(_):
|
||||
|
@ -3199,7 +3199,7 @@ dist.pmf(counts) # Shape [2]
|
||||
```
|
||||
- - -
|
||||
|
||||
#### `tf.contrib.distributions.DirichletMultinomial.__init__(n, alpha, allow_arbitrary_counts=False, strict=True, name='DirichletMultinomial')` {#DirichletMultinomial.__init__}
|
||||
#### `tf.contrib.distributions.DirichletMultinomial.__init__(n, alpha, allow_arbitrary_counts=False, allow_nan=False, strict=True, name='DirichletMultinomial')` {#DirichletMultinomial.__init__}
|
||||
|
||||
Initialize a batch of DirichletMultinomial distributions.
|
||||
|
||||
@ -3216,8 +3216,15 @@ Initialize a batch of DirichletMultinomial distributions.
|
||||
* <b>`allow_arbitrary_counts`</b>: Boolean. This represents whether the pmf/cdf
|
||||
allows for the `counts` tensor to be non-integral values.
|
||||
The pmf/cdf are functions that can be evaluated at non-integral values,
|
||||
but are only a distribution over non-negative integers.
|
||||
* <b>`strict`</b>: Not used (yet).
|
||||
but are only a distribution over non-negative integers. If `strict` is
|
||||
`False`, this assertion is turned off.
|
||||
* <b>`allow_nan`</b>: Boolean, default False. If False, raise an exception if
|
||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
||||
If True, batch members with valid parameters leading to undefined
|
||||
statistics will return NaN for this statistic.
|
||||
* <b>`strict`</b>: Whether to assert valid values for parameters `alpha` and `n`, and
|
||||
`x` in `pmf` and `log_pmf`. If False, correct behavior is not
|
||||
guaranteed.
|
||||
* <b>`name`</b>: The name to prefix Ops created by this distribution class.
|
||||
|
||||
|
||||
@ -3233,6 +3240,13 @@ dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
```
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.contrib.distributions.DirichletMultinomial.allow_nan` {#DirichletMultinomial.allow_nan}
|
||||
|
||||
Boolean describing behavior when a stat is undefined for batch member.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.contrib.distributions.DirichletMultinomial.alpha` {#DirichletMultinomial.alpha}
|
||||
|
@ -40,15 +40,24 @@ Initializes a BaseEstimator instance.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -318,15 +327,24 @@ Constructs an Estimator instance.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -599,15 +617,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -938,15 +965,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -1320,15 +1356,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -1603,15 +1648,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -1859,15 +1913,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -2491,15 +2554,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -2856,15 +2928,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -3139,15 +3220,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -3395,15 +3485,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
@ -4283,15 +4382,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -97,15 +97,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -115,15 +115,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -1,4 +1,4 @@
|
||||
#### `tf.train.Server.create_local_server(start=True)` {#Server.create_local_server}
|
||||
#### `tf.train.Server.create_local_server(config=None, start=True)` {#Server.create_local_server}
|
||||
|
||||
Creates a new single-process cluster running on the local host.
|
||||
|
||||
@ -10,6 +10,8 @@ single-process cluster containing a single task in a job called
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`config`</b>: (Optional.) A `tf.ConfigProto` that specifies default
|
||||
configuration options for all sessions that run on this server.
|
||||
* <b>`start`</b>: (Optional.) Boolean, indicating whether to start the server after
|
||||
creating it. Defaults to `True`.
|
||||
|
||||
|
@ -25,15 +25,24 @@ Initializes a BaseEstimator instance.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -33,15 +33,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -67,7 +67,7 @@ dist.pmf(counts) # Shape [2]
|
||||
```
|
||||
- - -
|
||||
|
||||
#### `tf.contrib.distributions.DirichletMultinomial.__init__(n, alpha, allow_arbitrary_counts=False, strict=True, name='DirichletMultinomial')` {#DirichletMultinomial.__init__}
|
||||
#### `tf.contrib.distributions.DirichletMultinomial.__init__(n, alpha, allow_arbitrary_counts=False, allow_nan=False, strict=True, name='DirichletMultinomial')` {#DirichletMultinomial.__init__}
|
||||
|
||||
Initialize a batch of DirichletMultinomial distributions.
|
||||
|
||||
@ -84,8 +84,15 @@ Initialize a batch of DirichletMultinomial distributions.
|
||||
* <b>`allow_arbitrary_counts`</b>: Boolean. This represents whether the pmf/cdf
|
||||
allows for the `counts` tensor to be non-integral values.
|
||||
The pmf/cdf are functions that can be evaluated at non-integral values,
|
||||
but are only a distribution over non-negative integers.
|
||||
* <b>`strict`</b>: Not used (yet).
|
||||
but are only a distribution over non-negative integers. If `strict` is
|
||||
`False`, this assertion is turned off.
|
||||
* <b>`allow_nan`</b>: Boolean, default False. If False, raise an exception if
|
||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
||||
If True, batch members with valid parameters leading to undefined
|
||||
statistics will return NaN for this statistic.
|
||||
* <b>`strict`</b>: Whether to assert valid values for parameters `alpha` and `n`, and
|
||||
`x` in `pmf` and `log_pmf`. If False, correct behavior is not
|
||||
guaranteed.
|
||||
* <b>`name`</b>: The name to prefix Ops created by this distribution class.
|
||||
|
||||
|
||||
@ -101,6 +108,13 @@ dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
```
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.contrib.distributions.DirichletMultinomial.allow_nan` {#DirichletMultinomial.allow_nan}
|
||||
|
||||
Boolean describing behavior when a stat is undefined for batch member.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.contrib.distributions.DirichletMultinomial.alpha` {#DirichletMultinomial.alpha}
|
||||
|
@ -42,15 +42,24 @@ Constructs an Estimator instance.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -9,7 +9,7 @@ communicate with any other server in the same cluster.
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.Server.__init__(server_or_cluster_def, job_name=None, task_index=None, protocol=None, start=True)` {#Server.__init__}
|
||||
#### `tf.train.Server.__init__(server_or_cluster_def, job_name=None, task_index=None, protocol=None, config=None, start=True)` {#Server.__init__}
|
||||
|
||||
Creates a new server with the given definition.
|
||||
|
||||
@ -32,6 +32,8 @@ override any information provided in `server_or_cluster_def`.
|
||||
* <b>`protocol`</b>: (Optional.) Specifies the protocol to be used by the server.
|
||||
Acceptable values include `"grpc"`. Defaults to the value in
|
||||
`server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
|
||||
* <b>`config`</b>: (Optional.) A `tf.ConfigProto` that specifies default
|
||||
configuration options for all sessions that run on this server.
|
||||
* <b>`start`</b>: (Optional.) Boolean, indicating whether to start the server
|
||||
after creating it. Defaults to `True`.
|
||||
|
||||
@ -43,7 +45,7 @@ override any information provided in `server_or_cluster_def`.
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.Server.create_local_server(start=True)` {#Server.create_local_server}
|
||||
#### `tf.train.Server.create_local_server(config=None, start=True)` {#Server.create_local_server}
|
||||
|
||||
Creates a new single-process cluster running on the local host.
|
||||
|
||||
@ -55,6 +57,8 @@ single-process cluster containing a single task in a job called
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`config`</b>: (Optional.) A `tf.ConfigProto` that specifies default
|
||||
configuration options for all sessions that run on this server.
|
||||
* <b>`start`</b>: (Optional.) Boolean, indicating whether to start the server after
|
||||
creating it. Defaults to `True`.
|
||||
|
||||
@ -84,6 +88,18 @@ with tf.Session(server.target):
|
||||
A string containing a session target for this server.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.Server.server_def` {#Server.server_def}
|
||||
|
||||
Returns the `tf.train.ServerDef` for this server.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A `tf.train.ServerDef` protocol buffer that describes the configuration
|
||||
of this server.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
|
@ -116,15 +116,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -33,15 +33,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -33,15 +33,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -4,7 +4,7 @@ Computes softmax activations.
|
||||
|
||||
For each batch `i` and class `j` we have
|
||||
|
||||
softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))
|
||||
softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
|
||||
|
||||
##### Args:
|
||||
|
||||
|
@ -33,15 +33,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -33,15 +33,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -114,15 +114,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
@ -33,15 +33,24 @@ Returns weights of deep neural network part.
|
||||
|
||||
Evaluates given model with provided evaluation data.
|
||||
|
||||
Evaluates on the given input data. If `input_fn` is provided, that
|
||||
input function should raise an end-of-input exception (`OutOfRangeError` or
|
||||
`StopIteration`) after one epoch of the training data has been provided.
|
||||
|
||||
By default, the whole evaluation dataset is used. If `steps` is provided,
|
||||
only `steps` batches of size `batch_size` are processed.
|
||||
|
||||
The return value is a dict containing the metrics specified in `metrics`, as
|
||||
well as an entry `global_step` which contains the value of the global step
|
||||
for which this evaluation was performed.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`x`</b>: features.
|
||||
* <b>`y`</b>: targets.
|
||||
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
|
||||
`None`. If `steps` is `None`, the tensors returned by this should
|
||||
generally raise an end-of-input exception when all eval records have
|
||||
been returned (typically, 1 epoch over eval data).
|
||||
`None`.
|
||||
* <b>`feed_fn`</b>: Function creating a feed dict every time it is called. Called
|
||||
once per iteration.
|
||||
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user