Merge pull request #5387 from vrv/branch_138101249

Branch 138101249
This commit is contained in:
Vijay Vasudevan 2016-11-03 13:37:13 -07:00 committed by GitHub
commit 76b2c0630b
245 changed files with 11614 additions and 8855 deletions

View File

@ -62,8 +62,6 @@ cc_library(
# This define (mostly) guarantees we don't link any problematic
# code. We use it, but we do not rely on it, as evidenced above.
"EIGEN_MPL2_ONLY",
# TODO(jart): Use EIGEN_USE_NONBLOCKING_THREAD_POOL but first add an
# eigen_initialize.cc file and alwayslink=1.
],
includes = ["."],
visibility = ["//visibility:public"],

View File

@ -105,6 +105,7 @@ filegroup(
"//tensorflow/contrib/framework:all_files",
"//tensorflow/contrib/graph_editor:all_files",
"//tensorflow/contrib/grid_rnn:all_files",
"//tensorflow/contrib/integrate:all_files",
"//tensorflow/contrib/layers:all_files",
"//tensorflow/contrib/layers/kernels:all_files",
"//tensorflow/contrib/learn:all_files",
@ -148,7 +149,6 @@ filegroup(
"//tensorflow/examples/image_retraining:all_files",
"//tensorflow/examples/label_image:all_files",
"//tensorflow/examples/learn:all_files",
"//tensorflow/examples/skflow:all_files",
"//tensorflow/examples/tutorials/estimators:all_files",
"//tensorflow/examples/tutorials/mnist:all_files",
"//tensorflow/examples/tutorials/word2vec:all_files",

View File

@ -264,6 +264,36 @@ tf_cc_test(
],
)
# Gradient functions for neural-network ops, built from gradients/nn_grad.cc.
# The source file only contributes static gradient registrations (it exports
# no symbols that dependents reference), so alwayslink is required to keep the
# object file from being discarded by the linker in statically linked binaries.
cc_library(
    name = "nn_grad",
    srcs = ["gradients/nn_grad.cc"],
    deps = [
        ":cc_ops",
        ":grad_op_registry",
        ":ops",
        ":scope",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
    ],
    alwayslink = 1,
)
# Unit test for the gradient functions in :nn_grad, built from
# gradients/nn_grad_test.cc. Depends on :gradient_checker, which the test uses
# to compare theoretical and numeric gradients.
tf_cc_test(
name = "gradients_nn_grad_test",
srcs = ["gradients/nn_grad_test.cc"],
deps = [
":cc_ops",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
":nn_grad",
":testutil",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_gen_op_wrappers_cc(
name = "cc_ops",
op_lib_names = [

View File

@ -110,20 +110,15 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const ops::Output& x,
return Status::OK();
}
} // namespace
template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
const TensorShape& x_shape, const ops::Output& y,
const TensorShape& y_shape, T* max_error) {
Status ComputeGradientErrorInternal(const Scope& scope, const ops::Output& x,
const TensorShape& x_shape,
const ops::Output& y,
const TensorShape& y_shape, Tensor* x_data,
T* max_error) {
const int64 x_size = x_shape.num_elements();
const int64 y_size = y_shape.num_elements();
// Initialize 'x_data' to random values.
Tensor x_data(x.type(), x_shape);
auto x_data_flat = x_data.flat<T>();
x_data_flat.setRandom();
// Initialize theoretical Jacobian to zeros.
Tensor jacobian_t(x.type(), {x_size, y_size});
auto jacobian_t_flat = jacobian_t.flat<T>();
@ -131,7 +126,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
// Compute theoretical Jacobian.
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
scope, x, x_shape, x_data, y, y_shape, &jacobian_t));
scope, x, x_shape, *x_data, y, y_shape, &jacobian_t));
// Initialize numeric Jacobian to zeros.
Tensor jacobian_n(x.type(), {x_size, y_size});
@ -140,7 +135,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
// Compute numeric Jacobian.
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
scope, x, x_shape, y, y_shape, 1e-3, &x_data, &jacobian_n));
scope, x, x_shape, y, y_shape, 1e-3, x_data, &jacobian_n));
// Compute the maximum error between theoretical and numeric Jacobians.
*max_error = 0.0;
@ -154,10 +149,39 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
return Status::OK();
}
} // namespace
// Gradient check with a randomly initialized input.
//
// Allocates a tensor with the dtype of 'x' and shape 'x_shape', fills it with
// random values of type T, and forwards to ComputeGradientErrorInternal, which
// writes the result to '*max_error'. Returns the status of that call.
template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
                            const TensorShape& x_shape, const ops::Output& y,
                            const TensorShape& y_shape, T* max_error) {
  // Random evaluation point for 'x'.
  Tensor random_input(x.type(), x_shape);
  random_input.flat<T>().setRandom();

  // Hand off to the shared implementation.
  return ComputeGradientErrorInternal(scope, x, x_shape, y, y_shape,
                                      &random_input, max_error);
}
// Gradient check at a caller-supplied evaluation point.
//
// Uses 'x_init_value' as the data for 'x' (its shape doubles as the shape of
// 'x') and forwards to ComputeGradientErrorInternal, which writes the result
// to '*max_error'. Returns the status of that call.
template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
                            const Tensor& x_init_value, const ops::Output& y,
                            const TensorShape& y_shape, T* max_error) {
  // Non-const Tensor handle constructed from the caller's value; the internal
  // routine takes a mutable pointer.
  Tensor seeded_input = x_init_value;

  // Hand off to the shared implementation.
  return ComputeGradientErrorInternal(scope, x, seeded_input.shape(), y,
                                      y_shape, &seeded_input, max_error);
}
#define INSTANTIATE_GRAD_ERR_TYPE(T) \
template Status ComputeGradientError<T>( \
const Scope& scope, const ops::Output& x, const TensorShape& x_shape, \
const ops::Output& y, const TensorShape& y_shape, T* max_error)
const ops::Output& y, const TensorShape& y_shape, T* max_error); \
template Status ComputeGradientError<T>( \
const Scope& scope, const ops::Output& x, const Tensor& x_init_value, \
const ops::Output& y, const TensorShape& y_shape, T* max_error);
INSTANTIATE_GRAD_ERR_TYPE(float);
INSTANTIATE_GRAD_ERR_TYPE(double);

View File

@ -30,6 +30,12 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
const TensorShape& x_shape, const ops::Output& y,
const TensorShape& y_shape, T* max_error);
// Overload of ComputeGradientError which takes an initial value for 'x'.
// 'x_init_value' supplies the data at which the gradient of 'y' with respect
// to 'x' is checked, in place of a separately passed shape for 'x'.
// 'y_shape' is the shape of 'y'; the computed error is written to
// '*max_error'.
template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
const Tensor& x_init_value, const ops::Output& y,
const TensorShape& y_shape, T* max_error);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_

View File

@ -0,0 +1,77 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
namespace tensorflow {
namespace ops {
namespace {
// Gradient function for the "Softmax" op.
//
// With y = softmax(x) and upstream gradient dL/dy (grad_inputs[0]), the
// Jacobian-vector product is computed without materializing the Jacobian:
//
//   dL/dx = (dL/dy - sum(dL/dy * y)) * y
//
// where the sum runs over dimension 1 and is reshaped to a column ({-1, 1})
// so that it broadcasts against dL/dy. The single result is appended to
// 'grad_outputs'; the scope's status is returned.
Status SoftmaxGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  const auto& upstream = grad_inputs[0];
  auto softmax_out = op.output(0);

  // dL/dy * y, elementwise.
  auto weighted = Mul(scope, upstream, softmax_out);
  // Per-row total of the weighted gradient, shaped as a column.
  auto row_totals = Sum(scope, weighted, {1});
  auto column = Reshape(scope, row_totals, {-1, 1});
  // (dL/dy - row total) * y.
  auto centered = Sub(scope, upstream, column);
  grad_outputs->push_back(Mul(scope, centered, softmax_out));
  return scope.status();
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
// Gradient function for the "Relu" op: builds a ReluGrad op from the upstream
// gradient (grad_inputs[0]) and the forward op's input, appends it to
// 'grad_outputs', and returns the scope's status.
Status ReluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  const auto& upstream = grad_inputs[0];
  const auto features = op.input(0);
  grad_outputs->push_back(ReluGrad(scope, upstream, features));
  return scope.status();
}
REGISTER_GRADIENT_OP("Relu", ReluGradHelper);
// Gradient function for the "Relu6" op: builds a Relu6Grad op from the
// upstream gradient (grad_inputs[0]) and the forward op's input, appends it to
// 'grad_outputs', and returns the scope's status.
Status Relu6GradHelper(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  const auto& upstream = grad_inputs[0];
  const auto features = op.input(0);
  grad_outputs->push_back(Relu6Grad(scope, upstream, features));
  return scope.status();
}
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);
// Gradient function for the "Elu" op: builds an EluGrad op from the upstream
// gradient (grad_inputs[0]) and the forward op's *output* (unlike the Relu
// helpers, which pass the forward input), appends it to 'grad_outputs', and
// returns the scope's status.
Status EluGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  const auto& upstream = grad_inputs[0];
  const auto activations = op.output(0);
  grad_outputs->push_back(EluGrad(scope, upstream, activations));
  return scope.status();
}
REGISTER_GRADIENT_OP("Elu", EluGradHelper);
} // anonymous namespace
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,91 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
namespace {
// Test fixture for the neural-network op gradient functions. Each test builds
// its graph in 'scope_', a fresh root scope created per fixture instance, and
// checks gradients through one of the RunTest overloads.
class NNGradTest : public ::testing::Test {
protected:
NNGradTest() : scope_(Scope::NewRootScope()) {}
// Checks the gradient of 'y' with respect to 'x' using the shape-based
// ComputeGradientError overload. Asserts the call succeeds and expects the
// reported maximum error (float) to be below 1e-4.
void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
const TensorShape& y_shape) {
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, x, x_shape, y, y_shape, &max_error));
EXPECT_LT(max_error, 1e-4);
}
// Same check, but evaluates the gradient at the caller-provided
// 'x_init_value' via the initial-value ComputeGradientError overload.
void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
const TensorShape& y_shape) {
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
EXPECT_LT(max_error, 1e-4);
}
// Scope in which the ops under test are constructed.
Scope scope_;
};
// Checks the Softmax gradient on a [32, 10] float placeholder, using the
// shape-based RunTest overload.
TEST_F(NNGradTest, SoftmaxGrad) {
  const TensorShape dims({32, 10});
  auto logits = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(dims));
  auto probs = Softmax(scope_, logits);
  RunTest(logits, dims, probs, dims);
}
// Checks the Relu gradient on a [5, 2] float placeholder at fixed inputs.
TEST_F(NNGradTest, ReluGrad) {
  const TensorShape dims({5, 2});
  auto features = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(dims));
  auto activations = Relu(scope_, features);
  // The evaluation points stay clear of zero, where the ReLU gradient is not
  // well defined.
  const Tensor inputs = test::AsTensor<float>(
      {-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, dims);
  RunTest(features, inputs, activations, dims);
}
// Checks the Relu6 gradient on a [5, 2] float placeholder at fixed inputs.
TEST_F(NNGradTest, Relu6Grad) {
  TensorShape shape({5, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Relu6(scope_, x);
  // Avoid input values where the ReLU6 gradient is not well defined (around
  // zero and six), but cover all three regions of the function: x < 0 and
  // x > 6, where the gradient is zero, and 0 < x < 6, where it is one.
  // Without points in the middle region a gradient that is identically zero
  // would pass this test.
  Tensor x_init_value = test::AsTensor<float>(
      {-0.9, -0.7, -0.5, 0.3, 0.5, 0.7, 6.1, 6.3, 6.5, 6.7}, {5, 2});
  RunTest(x, x_init_value, y, shape);
}
// Checks the Elu gradient on a [5, 2] float placeholder at fixed inputs
// spanning both negative and positive values.
TEST_F(NNGradTest, EluGrad) {
  const TensorShape dims({5, 2});
  auto features = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(dims));
  auto activations = Elu(scope_, features);
  const Tensor inputs = test::AsTensor<float>(
      {-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, dims);
  RunTest(features, inputs, activations, dims);
}
} // namespace
} // namespace tensorflow

View File

@ -23,6 +23,7 @@ py_library(
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/integrate:integrate_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",

View File

@ -28,6 +28,7 @@ from tensorflow.contrib import factorization
from tensorflow.contrib import framework
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
from tensorflow.contrib import integrate
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import linear_optimizer

View File

@ -76,7 +76,7 @@ def build_split_apply_merge_model():
# REINFORCE forward step
route_selection = st.StochasticTensor(
distributions.Categorical, logits=logits)
distributions.Categorical(logits=logits))
# Accessing route_selection as a Tensor below forces a sample of
# the Categorical distribution based on its logits.

View File

@ -22,6 +22,7 @@ import tensorflow as tf
st = tf.contrib.bayesflow.stochastic_tensor
sge = tf.contrib.bayesflow.stochastic_gradient_estimators
dists = tf.contrib.distributions
class StochasticGradientEstimatorsTest(tf.test.TestCase):
@ -31,7 +32,7 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
self._final_loss = tf.constant(3.2)
def _testScoreFunction(self, loss_fn, expected):
x = st.BernoulliTensor(p=self._p, loss_fn=loss_fn)
x = st.StochasticTensor(dists.Bernoulli(p=self._p), loss_fn=loss_fn)
sf = x.loss(self._final_loss)
with self.test_session() as sess:
sess.run(tf.initialize_all_variables())
@ -62,8 +63,8 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
def testScoreFunctionWithMeanBaseline(self):
ema_decay = 0.8
num_steps = 6
x = st.BernoulliTensor(
p=self._p,
x = st.StochasticTensor(
dists.Bernoulli(p=self._p),
loss_fn=sge.get_score_function_with_baseline(
sge.get_mean_baseline(ema_decay)))
sf = x.loss(self._final_loss)
@ -98,12 +99,12 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
ema_decay = 0.8
x = st.BernoulliTensor(
p=self._p,
x = st.StochasticTensor(
dists.Bernoulli(p=self._p),
loss_fn=sge.get_score_function_with_baseline(
sge.get_mean_baseline(ema_decay)))
y = st.BernoulliTensor(
p=self._p,
y = st.StochasticTensor(
dists.Bernoulli(p=self._p),
loss_fn=sge.get_score_function_with_baseline(
sge.get_mean_baseline(ema_decay)))
sf_x = x.loss(self._final_loss)

View File

@ -39,9 +39,9 @@ class TestSurrogateLosses(tf.test.TestCase):
mu = [0.0, 0.1, 0.2]
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleAndReshapeValue()):
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
likelihood = st.StochasticTensor(
distributions.Normal, mu=prior, sigma=sigma)
distributions.Normal(mu=prior, sigma=sigma))
self.assertTrue(prior.distribution.is_reparameterized)
self.assertTrue(likelihood.distribution.is_reparameterized)
@ -77,10 +77,9 @@ class TestSurrogateLosses(tf.test.TestCase):
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleAndReshapeValue()):
prior = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
likelihood = st.StochasticTensor(
NormalNotParam, mu=prior, sigma=sigma)
prior_2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma))
prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
loss = tf.square(tf.identity(likelihood) - mu)
part_loss = tf.square(tf.identity(prior) - mu)
@ -155,9 +154,7 @@ class TestSurrogateLosses(tf.test.TestCase):
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleAndReshapeValue()):
dt = st.StochasticTensor(NormalNotParam,
mu=mu,
sigma=sigma,
dt = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma),
loss_fn=None)
self.assertEqual(None, dt.loss(tf.constant([2.0])))
@ -166,8 +163,8 @@ class TestSurrogateLosses(tf.test.TestCase):
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleAndReshapeValue()):
dt1 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
dt2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
loss = tf.square(tf.identity(dt1)) + 10. + dt2
sl_all = sg.surrogate_loss([loss])
@ -186,8 +183,8 @@ class TestSurrogateLosses(tf.test.TestCase):
class StochasticDependenciesMapTest(tf.test.TestCase):
def testBuildsMapOfUpstreamNodes(self):
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
dt2 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
dt2 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
out1 = dt1.value() + 1.
out2 = dt2.value() + 2.
x = out1 + out2
@ -197,11 +194,11 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
self.assertEqual(dep_map[dt2], set([x, y]))
def testHandlesStackedStochasticNodes(self):
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
out1 = dt1.value() + 1.
dt2 = st.StochasticTensor(distributions.Normal, mu=out1, sigma=1.)
dt2 = st.StochasticTensor(distributions.Normal(mu=out1, sigma=1.))
x = dt2.value() + 2.
dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
y = dt3.value() * 3.
dep_map = sg._stochastic_dependencies_map([x, y])
self.assertEqual(dep_map[dt1], set([x]))
@ -209,10 +206,10 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
self.assertEqual(dep_map[dt3], set([y]))
def testTraversesControlInputs(self):
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
logits = dt1.value() * 3.
dt2 = st.StochasticTensor(distributions.Bernoulli, logits=logits)
dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits))
dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
x = dt3.value()
y = tf.ones((2, 2)) * 4.
z = tf.ones((2, 2)) * 3.

View File

@ -35,19 +35,19 @@ class StochasticTensorTest(tf.test.TestCase):
sigma2 = tf.constant([0.1, 0.2, 0.3])
prior_default = st.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma)
distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(
isinstance(prior_default.value_type, st.SampleAndReshapeValue))
prior_0 = st.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma,
distributions.Normal(mu=mu, sigma=sigma),
dist_value_type=st.SampleAndReshapeValue())
self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue))
with st.value_type(st.SampleAndReshapeValue()):
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue))
likelihood = st.StochasticTensor(
distributions.Normal, mu=prior, sigma=sigma2)
distributions.Normal(mu=prior, sigma=sigma2))
self.assertTrue(
isinstance(likelihood.value_type, st.SampleAndReshapeValue))
@ -77,7 +77,7 @@ class StochasticTensorTest(tf.test.TestCase):
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.MeanValue()):
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(isinstance(prior.value_type, st.MeanValue))
prior_mean = prior.mean()
@ -94,7 +94,8 @@ class StochasticTensorTest(tf.test.TestCase):
with st.value_type(st.SampleAndReshapeValue()):
prior_single = st.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma)
distributions.Normal(
mu=mu, sigma=sigma))
prior_single_value = prior_single.value()
self.assertEqual(prior_single_value.get_shape(), (2, 3))
@ -104,7 +105,7 @@ class StochasticTensorTest(tf.test.TestCase):
with st.value_type(st.SampleAndReshapeValue(n=2)):
prior_double = st.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma)
distributions.Normal(mu=mu, sigma=sigma))
prior_double_value = prior_double.value()
self.assertEqual(prior_double_value.get_shape(), (4, 3))
@ -119,7 +120,7 @@ class StochasticTensorTest(tf.test.TestCase):
with st.value_type(st.SampleValue()):
prior_single = st.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma)
distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
prior_single_value = prior_single.value()
@ -130,7 +131,7 @@ class StochasticTensorTest(tf.test.TestCase):
with st.value_type(st.SampleValue(n=2)):
prior_double = st.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma)
distributions.Normal(mu=mu, sigma=sigma))
prior_double_value = prior_double.value()
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
@ -143,9 +144,9 @@ class StochasticTensorTest(tf.test.TestCase):
mu = [0.0, -1.0, 1.0]
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.MeanValue()):
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
entropy = prior.entropy()
deep_entropy = prior.entropy()
deep_entropy = prior.distribution.entropy()
expected_deep_entropy = distributions.Normal(
mu=mu, sigma=sigma).entropy()
entropies = sess.run([entropy, deep_entropy, expected_deep_entropy])
@ -159,17 +160,15 @@ class StochasticTensorTest(tf.test.TestCase):
# With default
with st.value_type(st.MeanValue(stop_gradient=True)):
dt = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
dt = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
loss = dt.loss([tf.constant(2.0)])
self.assertTrue(loss is not None)
self.assertAllClose(dt.distribution.log_prob(mu).eval() * 2.0,
loss.eval())
self.assertAllClose(
dt.distribution.log_prob(mu).eval() * 2.0, loss.eval())
# With passed-in loss_fn.
dt = st.StochasticTensor(
distributions.Normal,
mu=mu,
sigma=sigma,
distributions.Normal(mu=mu, sigma=sigma),
dist_value_type=st.MeanValue(stop_gradient=True),
loss_fn=sge.get_score_function_with_constant_baseline(
baseline=tf.constant(8.0)))
@ -204,7 +203,7 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
sigma = tf.constant([1.1, 1.2, 1.3])
obs = tf.zeros((2, 3))
z = st.ObservedStochasticTensor(
distributions.Normal, mu=mu, sigma=sigma, value=obs)
distributions.Normal(mu=mu, sigma=sigma), value=obs)
[obs_val, z_val] = sess.run([obs, z.value()])
self.assertAllEqual(obs_val, z_val)
@ -216,13 +215,13 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
sigma = tf.placeholder(tf.float32)
obs = tf.placeholder(tf.float32)
z = st.ObservedStochasticTensor(
distributions.Normal, mu=mu, sigma=sigma, value=obs)
distributions.Normal(mu=mu, sigma=sigma), value=obs)
mu2 = tf.placeholder(tf.float32, shape=[None])
sigma2 = tf.placeholder(tf.float32, shape=[None])
obs2 = tf.placeholder(tf.float32, shape=[None, None])
z2 = st.ObservedStochasticTensor(
distributions.Normal, mu=mu2, sigma=sigma2, value=obs2)
distributions.Normal(mu=mu2, sigma=sigma2), value=obs2)
coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
self.assertEqual(coll, [z, z2])
@ -230,27 +229,19 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
def testConstructionErrors(self):
mu = [0., 0.]
sigma = [1., 1.]
self.assertRaises(ValueError, st.ObservedStochasticTensor,
distributions.Normal, mu=mu, sigma=sigma,
value=tf.zeros((3,)))
self.assertRaises(ValueError, st.ObservedStochasticTensor,
distributions.Normal, mu=mu, sigma=sigma,
value=tf.zeros((3, 1)))
self.assertRaises(ValueError, st.ObservedStochasticTensor,
distributions.Normal, mu=mu, sigma=sigma,
value=tf.zeros((1, 2), dtype=tf.int32))
class AutomaticDistributionImportTest(tf.test.TestCase):
def testImportNormal(self):
self.assertTrue(hasattr(st, "NormalTensor"))
self.assertTrue(callable(st.NormalTensor))
norm = st.NormalTensor(mu=0.0, sigma=1.0)
self.assertEqual(type(norm).__name__, "NormalTensor")
self.assertTrue(isinstance(norm, st.NormalTensor))
self.assertTrue(isinstance(norm, st.StochasticTensor))
if __name__ == "__main__":
tf.test.main()
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(mu=mu, sigma=sigma),
value=tf.zeros((3,)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(mu=mu, sigma=sigma),
value=tf.zeros((3, 1)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(mu=mu, sigma=sigma),
value=tf.zeros(
(1, 2), dtype=tf.int32))

View File

@ -44,7 +44,7 @@ def mini_vae():
x = [[-6., 3., 6.], [-8., 4., 8.]]
prior = distributions.Normal(mu=0., sigma=1.)
variational = st.StochasticTensor(
distributions.Normal, mu=inference_net(x, 1), sigma=1.)
distributions.Normal(mu=inference_net(x, 1), sigma=1.))
vi.register_prior(variational, prior)
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
@ -101,7 +101,7 @@ class VariationalInferenceTest(tf.test.TestCase):
prior = distributions.Bernoulli(0.5)
variational = st.StochasticTensor(
NormalNoEntropy, mu=inference_net(x, 1), sigma=1.)
NormalNoEntropy(mu=inference_net(x, 1), sigma=1.))
vi.register_prior(variational, prior)
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)

View File

@ -44,7 +44,6 @@ from __future__ import print_function
import abc
import collections
import contextlib
import inspect
import threading
import six
@ -79,10 +78,6 @@ class BaseStochasticTensor(object):
def graph(self):
pass
@abc.abstractproperty
def input_dict(self):
pass
@abc.abstractmethod
def value(self, name=None):
pass
@ -120,6 +115,7 @@ class BaseStochasticTensor(object):
# pylint: disable=protected-access
ops.register_tensor_conversion_function(
BaseStochasticTensor, BaseStochasticTensor._tensor_conversion_function)
# pylint: enable=protected-access
@ -223,8 +219,8 @@ class SampleAndReshapeValue(_StochasticValueType):
st_value = st.value()
assertEqual(st_value.get_shape(), (4, 3))
dt_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
assertEqual(dt_value_val.shape, (4, 3))
st_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
assertEqual(st_value_val.shape, (4, 3))
```
"""
@ -312,17 +308,16 @@ class StochasticTensor(BaseStochasticTensor):
"""StochasticTensor is a BaseStochasticTensor backed by a distribution."""
def __init__(self,
dist_cls,
name=None,
dist,
name="StochasticTensor",
dist_value_type=None,
loss_fn=sge.score_function,
**dist_args):
loss_fn=sge.score_function):
"""Construct a `StochasticTensor`.
`StochasticTensor` will instantiate a distribution from `dist_cls` and
`dist_args` and its `value` method will return the same value each time
it is called. What `value` is returned is controlled by the
`dist_value_type` (defaults to `SampleAndReshapeValue`).
`StochasticTensor` is backed by the `dist` distribution and its `value`
method will return the same value each time it is called. What `value` is
returned is controlled by the `dist_value_type` (defaults to
`SampleAndReshapeValue`).
Some distributions' sample functions are not differentiable (e.g. a sample
from a discrete distribution like a Bernoulli) and so to differentiate
@ -338,28 +333,25 @@ class StochasticTensor(BaseStochasticTensor):
`MeanValueType` or if `loss_fn=None`.
Args:
dist_cls: a `Distribution` class.
dist: an instance of `Distribution`.
name: a name for this `StochasticTensor` and its ops.
dist_value_type: a `_StochasticValueType`, which will determine what the
`value` of this `StochasticTensor` will be. If not provided, the
value type set with the `value_type` context manager will be used.
loss_fn: callable that takes `(st, st.value(), influenced_loss)`, where
loss_fn: callable that takes
`(st, st.value(), influenced_loss)`, where
`st` is this `StochasticTensor`, and returns a `Tensor` loss. By
default, `loss_fn` is the `score_function`, or more precisely, the
integral of the score function, such that when the gradient is taken,
the score function results. See the `stochastic_gradient_estimators`
module for additional loss functions and baselines.
**dist_args: keyword arguments to be passed through to `dist_cls` on
construction.
Raises:
TypeError: if `dist_cls` is not a `Distribution`.
TypeError: if `dist` is not an instance of `Distribution`.
TypeError: if `loss_fn` is not `callable`.
"""
if not issubclass(dist_cls, distributions.Distribution):
raise TypeError("dist_cls must be a subclass of Distribution")
self._dist_cls = dist_cls
self._dist_args = dist_args
if not isinstance(dist, distributions.Distribution):
raise TypeError("dist must be an instance of Distribution")
if dist_value_type is None:
try:
self._value_type = get_current_value_type()
@ -371,24 +363,17 @@ class StochasticTensor(BaseStochasticTensor):
with value_type(dist_value_type):
self._value_type = get_current_value_type()
self._value_type.declare_inputs(self, dist_args)
if loss_fn is not None and not callable(loss_fn):
raise TypeError("loss_fn must be callable")
self._loss_fn = loss_fn
with ops.name_scope(name, "StochasticTensor",
dist_args.values()) as scope:
with ops.name_scope(name) as scope:
self._name = scope
self._dist = dist_cls(**dist_args)
self._dist = dist
self._value = self._create_value()
super(StochasticTensor, self).__init__()
@property
def input_dict(self):
return self._dist_args
@property
def value_type(self):
return self._value_type
@ -397,9 +382,6 @@ class StochasticTensor(BaseStochasticTensor):
def distribution(self):
return self._dist
def clone(self, name=None, **dist_args):
return StochasticTensor(self._dist_cls, name=name, **dist_args)
def _create_value(self):
"""Create the value Tensor based on the value type, store as self._value."""
@ -494,33 +476,28 @@ class ObservedStochasticTensor(StochasticTensor):
"""A StochasticTensor with an observed value."""
# pylint: disable=super-init-not-called
def __init__(self, dist_cls, value, name=None, **dist_args):
def __init__(self, dist, value, name=None):
"""Construct an `ObservedStochasticTensor`.
`ObservedStochasticTensor` will instantiate a distribution from `dist_cls`
and `dist_args` but use the provided value instead of sampling from the
distribution. The provided value argument must be appropriately shaped
to have come from the constructed distribution.
`ObservedStochasticTensor` is backed by distribution `dist` and uses the
provided value instead of using the current value type to draw a value from
the distribution. The provided value argument must be appropriately shaped
to have come from the distribution.
Args:
dist_cls: a `Distribution` class.
dist: an instance of `Distribution`.
value: a Tensor containing the observed value
name: a name for this `ObservedStochasticTensor` and its ops.
**dist_args: keyword arguments to be passed through to `dist_cls` on
construction.
Raises:
TypeError: if `dist_cls` is not a `Distribution`.
TypeError: if `dist` is not an instance of `Distribution`.
ValueError: if `value` is not compatible with the distribution.
"""
if not issubclass(dist_cls, distributions.Distribution):
raise TypeError("dist_cls must be a subclass of Distribution")
self._dist_cls = dist_cls
self._dist_args = dist_args
with ops.name_scope(name, "ObservedStochasticTensor",
list(dist_args.values()) + [value]) as scope:
if not isinstance(dist, distributions.Distribution):
raise TypeError("dist must be an instance of Distribution")
with ops.name_scope(name, "ObservedStochasticTensor", [value]) as scope:
self._name = scope
self._dist = dist_cls(**dist_args)
self._dist = dist
dist_shape = self._dist.get_batch_shape().concatenate(
self._dist.get_event_shape())
value = ops.convert_to_tensor(value)
@ -538,7 +515,7 @@ class ObservedStochasticTensor(StochasticTensor):
"sample from the distribution %s." % (value_shape, dist_shape))
if value.dtype != self._dist.dtype:
raise ValueError("Type of observed value (%s) does not match type of "
"distribuiton (%s)." % (value.dtype, self._dist.dtype))
"distribution (%s)." % (value.dtype, self._dist.dtype))
self._value = array_ops.identity(value)
# pylint: disable=non-parent-init-called
BaseStochasticTensor.__init__(self)
@ -557,39 +534,3 @@ __all__ = [
"value_type",
"get_current_value_type",
]
_globals = globals()
# pylint: disable=redefined-builtin
__doc__ += "\n\n## Automatically Generated StochasticTensors\n\n"
# pylint: enable=redefined-builtin
for _name in sorted(dir(distributions)):
_candidate = getattr(distributions, _name)
if (inspect.isclass(_candidate)
and _candidate != distributions.Distribution
and issubclass(_candidate, distributions.Distribution)):
_local_name = "%sTensor" % _name
class _WrapperTensor(StochasticTensor):
_my_candidate = _candidate
def __init__(self, name=None, dist_value_type=None,
loss_fn=sge.score_function, **dist_args):
StochasticTensor.__init__(
self,
dist_cls=self._my_candidate,
name=name,
dist_value_type=dist_value_type,
loss_fn=loss_fn, **dist_args)
_WrapperTensor.__name__ = _local_name
_WrapperTensor.__doc__ = (
"`%s` is a `StochasticTensor` backed by the distribution `%s`."""
% (_local_name, _name))
_globals[_local_name] = _WrapperTensor
del _WrapperTensor
del _candidate
__all__.append(_local_name)
__doc__ += "@@%s\n" % _local_name
del _local_name

View File

@ -126,7 +126,7 @@ def get_stochastic_variable(getter,
dist_kwargs = dist_kwargs or {}
dist_kwargs.update(params)
sample = st.StochasticTensor(dist_cls, **dist_kwargs)
sample = st.StochasticTensor(dist_cls(**dist_kwargs))
if prior is not None:
if callable(prior):

View File

@ -325,7 +325,7 @@ class FillLowerTriangularTest(tf.test.TestCase):
def testCorrectlyMakesNoBatchLowerTril(self):
with self.test_session():
x = np.arange(9)
x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
expected = np.array(
[[0., 0., 0.],
[1., 2., 0.],
@ -333,6 +333,10 @@ class FillLowerTriangularTest(tf.test.TestCase):
actual = distribution_util.fill_lower_triangular(x)
self.assertAllEqual(expected.shape, actual.get_shape())
self.assertAllEqual(expected, actual.eval())
self.assertAllEqual(
np.concatenate([np.ones(6, dtype=np.float32),
np.zeros(3, dtype=np.float32)]),
tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval())
def testCorrectlyMakesBatchLowerTril(self):
with self.test_session():

View File

@ -435,15 +435,14 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
"""
with ops.name_scope(name, values=(x,)):
x = ops.convert_to_tensor(x, name="x")
ndims = x.get_shape().ndims
if ndims is not None and x.get_shape()[-1].value is not None:
if (x.get_shape().ndims is not None and
x.get_shape()[-1].value is not None):
d = x.get_shape()[-1].value
# d = n^2/2 + n/2 implies n is:
n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
final_shape = x.get_shape()[:-1].concatenate(
tensor_shape.TensorShape([n, n]))
else:
ndims = array_ops.rank(x)
d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
# d = n^2/2 + n/2 implies n is:
n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.),
@ -494,7 +493,12 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
array_ops.tile([tril_ids], [m, 1])])
idx = array_ops.transpose(idx, [1, 2, 0])
y = array_ops.gather_nd(y, idx)
if x.get_shape().ndims == 1:
# Prefer using gather because it has a gradient.
# We wrap the result in a list so downstream logic "just works."
y = [array_ops.gather(y[0, :], tril_ids)]
else:
y = array_ops.gather_nd(y, idx)
y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
y.set_shape(y.get_shape().merge_with(final_shape))

View File

@ -571,9 +571,8 @@ class WALSModel(object):
extras = size % num_shards
assignments = tf.maximum(ids // (ids_per_shard + 1),
(ids - extras) // ids_per_shard)
new_ids = tf.select(assignments < extras,
ids % (ids_per_shard + 1),
(ids - extras) % ids_per_shard)
new_ids = tf.where(assignments < extras, ids % (ids_per_shard + 1),
(ids - extras) % ids_per_shard)
return assignments, new_ids
return func

View File

@ -36,7 +36,7 @@ class LocalVariableTest(tf.test.TestCase):
variables = tf.local_variables()
self.assertEquals(2, len(variables))
self.assertRaises(tf.OpError, sess.run, variables)
tf.initialize_variables(variables).run()
tf.variables_initializer(variables).run()
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
def testLocalVariableNameAndShape(self):
@ -51,7 +51,7 @@ class LocalVariableTest(tf.test.TestCase):
with self.test_session():
with tf.variable_scope('A'):
a = tf.contrib.framework.local_variable(0)
self.assertFalse(a in tf.all_variables())
self.assertFalse(a in tf.global_variables())
self.assertTrue(a in tf.local_variables())
def testLocalVariableNotInVariablesToRestore(self):
@ -82,7 +82,7 @@ class LocalVariableTest(tf.test.TestCase):
def testInitializedVariableValue(self):
with self.test_session() as sess:
a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
sess.run(tf.initialize_local_variables())
sess.run(tf.local_variables_initializer())
self.assertAllEqual(a.eval(), [0]*5)
@ -439,7 +439,7 @@ class ModelVariablesTest(tf.test.TestCase):
with self.test_session():
with tf.variable_scope('A'):
a = tf.contrib.framework.model_variable('a', [5])
self.assertTrue(a in tf.all_variables())
self.assertTrue(a in tf.global_variables())
self.assertTrue(a in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
self.assertFalse(a in tf.local_variables())
@ -474,7 +474,7 @@ class ModelVariablesTest(tf.test.TestCase):
with self.test_session() as sess:
a = tf.contrib.framework.model_variable(
'a', [5], initializer=tf.ones_initializer)
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
self.assertAllEqual(a.eval(), [1]*5)
def testDeviceFn(self):
@ -667,7 +667,7 @@ class AssignFromValuesTest(tf.test.TestCase):
var_names_to_values)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
sess.run(assign_op, feed_dict)
@ -697,7 +697,7 @@ class AssignFromValuesTest(tf.test.TestCase):
var_names_to_values)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
sess.run(assign_op, feed_dict)
@ -725,7 +725,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
init_fn(sess)
@ -754,7 +754,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
init_fn(sess)
@ -786,7 +786,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
var_value = var_names_to_values[var_name]
var_list.append(tf.Variable(var_value, name=var_name))
saver = tf.train.Saver(var_list)
init_op = tf.initialize_variables(var_list)
init_op = tf.variables_initializer(var_list)
sess.run(init_op)
# Save the initialized values in the file at 'checkpoint_dir'
return saver.save(sess, checkpoint_dir, global_step=global_step)
@ -808,7 +808,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
model_path, vars_to_restore)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
sess.run(op, feed_dict)
@ -859,7 +859,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
vars_to_restore)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
sess.run(op, feed_dict)
@ -890,7 +890,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
var_value = var_names_to_values[var_name]
var_list.append(tf.Variable(var_value, name=var_name))
saver = tf.train.Saver(var_list)
init_op = tf.initialize_variables(var_list)
init_op = tf.variables_initializer(var_list)
sess.run(init_op)
# Save the initialized values in the file at 'checkpoint_dir'
return saver.save(sess, checkpoint_dir, global_step=global_step)
@ -912,7 +912,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
model_path, vars_to_restore)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
init_fn(sess)
@ -938,7 +938,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
model_path, vars_to_restore)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
with self.assertRaises(tf.errors.InvalidArgumentError):
@ -961,7 +961,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
model_path, vars_to_restore, reshape_variables=True)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
init_fn(sess)
@ -989,7 +989,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
vars_to_restore)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
with self.assertRaises(tf.errors.NotFoundError):
@ -1015,7 +1015,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
ignore_missing_vars=True)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
init_fn(sess)
@ -1044,7 +1044,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
ignore_missing_vars=True)
# Initialize the variables.
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
# Perform the assignment.
init_fn(sess)

View File

@ -0,0 +1,38 @@
# Description:
# Integration and ODE solvers for TensorFlow.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
py_library(
name = "integrate_py",
srcs = [
"__init__.py",
"python/ops/odes.py",
],
srcs_version = "PY2AND3",
)
py_test(
name = "odes_test",
srcs = ["python/ops/odes_test.py"],
srcs_version = "PY2AND3",
deps = [
":integrate_py",
"//tensorflow:tensorflow_py",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)

View File

@ -0,0 +1,9 @@
# Integration and ODE solvers for TensorFlow
TensorFlow equivalents to the routines provided by `scipy.integrate`. Currently
contains a single function, `odeint`, for integrating ordinary differential
equations.
Maintainers:
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
- Marc Coram (mcoram@google.com, github.com/mcoram)

View File

@ -0,0 +1,64 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration and ODE solvers for TensorFlow.
## Example: Lorenz attractor
We can use `odeint` to solve the
[Lorenz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
differential equations, a prototypical example of chaotic dynamics:
```python
rho = 28.0
sigma = 10.0
beta = 8.0/3.0
def lorenz_equation(state, t):
x, y, z = tf.unpack(state)
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return tf.pack([dx, dy, dz])
init_state = tf.constant([0, 2, 20], dtype=tf.float64)
t = np.linspace(0, 50, num=5000)
tensor_state, tensor_info = tf.contrib.integrate.odeint(
lorenz_equation, init_state, t, full_output=True)
sess = tf.Session()
state, info = sess.run([tensor_state, tensor_info])
x, y, z = state.T
plt.plot(x, z)
```
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/lorenz_attractor.png" alt>
</div>
## Ops
@@odeint
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.integrate.python.ops.odes import *
from tensorflow.python.util.all_util import make_all
__all__ = make_all(__name__)

View File

@ -0,0 +1,503 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ODE solvers for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# Container for the coefficients of an explicit Runge-Kutta scheme, as consumed
# by `_runge_kutta_step` and `_interp_fit_rk`:
#   alpha: per-stage multipliers of `dt` giving the stage times
#     (`ti = t0 + alpha_i * dt`).
#   beta: per-stage lists of weights applied to the derivatives `k` computed so
#     far when forming each intermediate state.
#   c_sol: weights over `k` used to form the solution at the end of the step.
#   c_mid: weights over `k` used to form the mid-point value for interpolation.
#   c_error: weights over `k` used to form the error estimate for the step.
_ButcherTableau = collections.namedtuple(
    '_ButcherTableau', 'alpha beta c_sol c_mid c_error')

# Parameters from Shampine (1986), section 4.
# NOTE: the fractions below rely on true division (`from __future__ import
# division` at the top of this file); they must evaluate to floats.
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
    alpha=[1/5, 3/10, 4/5, 8/9, 1., 1.],
    beta=[[1/5],
          [3/40, 9/40],
          [44/45, -56/15, 32/9],
          [19372/6561, -25360/2187, 64448/6561, -212/729],
          [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
          [35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]],
    # Same as the last row of `beta`, followed by a zero weight for the final
    # derivative evaluation.
    c_sol=[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0],
    c_mid=[6025192743/30085553152 / 2, 0, 51252292925/65400821598 / 2,
           -2691868925/45128329728 / 2, 187940372067/1594534317056 / 2,
           -1776094331/19743644256 / 2, 11237099/235043384 / 2],
    # Each entry is written as a difference from the matching `c_sol` weight.
    c_error=[1951/21600 - 35/384,
             0,
             22642/50085 - 500/1113,
             451/720 - 125/192,
             -12231/42400 - -2187/6784,
             649/6300 - 11/84,
             1/60],
)
def _possibly_nonzero(x):
  """Return a truthy value unless `x` is a non-Tensor that compares equal to 0.

  Tensors are always reported as possibly nonzero, because their values are not
  inspected here.
  """
  if isinstance(x, ops.Tensor):
    return True
  return x != 0
def _scaled_dot_product(scale, xs, ys, name=None):
  """Calculate a scaled, vector inner product between lists of Tensors.

  Computes `sum(scale * x * y for x, y in zip(xs, ys))`, skipping every pair in
  which either factor is a plain Python zero.

  Args:
    scale: multiplier applied to every term.
    xs: list of coefficients (Python numbers or Tensors).
    ys: list of values (Python numbers or Tensors); extra trailing elements of
      the longer list are ignored, as with `zip`.
    name: optional name for the resulting op.

  Returns:
    The result of `add_n` over the retained `(scale * x) * y` terms.
  """
  with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
    pairs = list(zip(xs, ys))
    # Some of the parameters in our Butcher tableau include zeros. A term can
    # only contribute if *both* of its factors are possibly nonzero, so we
    # require both; testing with `or` never dropped anything because the `ys`
    # are Tensors.
    kept = [(x, y) for x, y in pairs
            if _possibly_nonzero(x) and _possibly_nonzero(y)]
    if not kept:
      # add_n requires at least one input; if every term is a known zero, fall
      # back to summing all of them so the result still has the right
      # shape/dtype.
      kept = pairs
    return math_ops.add_n([(scale * x) * y for x, y in kept], name=scope)
def _dot_product(xs, ys, name=None):
  """Sum the element-wise products of two equally ordered lists of Tensors.

  Args:
    xs: first sequence of factors.
    ys: second sequence of factors; paired with `xs` via `zip`.
    name: optional name for the resulting op.

  Returns:
    The result of `add_n` over `x * y` for each pair.
  """
  with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
    products = []
    for left, right in zip(xs, ys):
      products.append(left * right)
    return math_ops.add_n(products, name=scope)
def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU,
                      name=None):
  """Take an arbitrary Runge-Kutta step and estimate error.

  Args:
    func: Function to evaluate like `func(y, t)` to compute the time derivative
      of `y`.
    y0: Tensor initial value for the state.
    f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
    t0: float64 scalar Tensor giving the initial time.
    dt: float64 scalar Tensor giving the size of the desired time step.
    tableau: optional _ButcherTableau describing how to take the Runge-Kutta
      step.
    name: optional name for the operation.

  Returns:
    Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
    the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
    estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
    calculating these terms.
  """
  with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
    y0 = ops.convert_to_tensor(y0, name='y0')
    f0 = ops.convert_to_tensor(f0, name='f0')
    t0 = ops.convert_to_tensor(t0, name='t0')
    dt = ops.convert_to_tensor(dt, name='dt')
    # The state may have a different dtype than the time step (e.g. complex).
    dt_cast = math_ops.cast(dt, y0.dtype)

    # k[i] holds the derivative evaluated at stage i; each new stage is built
    # from the derivatives of all previous stages.
    k = [f0]
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
      ti = t0 + alpha_i * dt
      yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
      k.append(func(yi, ti))

    # `c_sol` has one more entry than the last row of `beta` (one weight per
    # element of `k`), so the rows must be compared without the trailing entry;
    # comparing the full lists could never succeed.
    last_stage_is_solution = (
        tableau.c_sol[-1] == 0 and
        list(tableau.c_sol[:-1]) == list(tableau.beta[-1]))
    if not last_stage_is_solution:
      # Otherwise (not the case for Dormand-Prince) the solution needs its own
      # weighted sum; when the property holds we save a few FLOPs by reusing
      # the last intermediate state.
      yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)

    y1 = array_ops.identity(yi, name='%s/y1' % scope)
    f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
    y1_error = _scaled_dot_product(dt_cast, tableau.c_error, k,
                                   name='%s/y1_error' % scope)
    return (y1, f1, y1_error, k)
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
  """Fit coefficients for 4th order polynomial interpolation.

  Args:
    y0: function value at the start of the interval.
    y1: function value at the end of the interval.
    y_mid: function value at the mid-point of the interval.
    f0: derivative value at the start of the interval.
    f1: derivative value at the end of the interval.
    dt: width of the interval.

  Returns:
    List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
    `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
    between 0 (start of interval) and 1 (end of interval).
  """
  # The weights below come from solving the five matching conditions with sympy:
  #   a, b, c, d, e = sympy.symbols('a b c d e')
  #   x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
  #   p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
  #   sympy.solve([p.subs(x, 0) - y0,
  #                p.subs(x, 1 / 2) - y_mid,
  #                p.subs(x, 1) - y1,
  #                (p.diff(x) / dt).subs(x, 0) - f0,
  #                (p.diff(x) / dt).subs(x, 1) - f1],
  #               [a, b, c, d, e])
  # which yields
  #   a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid
  #   b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid
  #   c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid
  #   d: dt*f0
  #   e: y0
  samples = [f0, f1, y0, y1, y_mid]
  quartic = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], samples)
  cubic = _dot_product([5 * dt, -3 * dt, 18, 14, -32], samples)
  quadratic = _dot_product([-4 * dt, dt, -11, -5, 16], samples)
  linear = dt * f0
  return [quartic, cubic, quadratic, linear, y0]
def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
  """Fit an interpolating polynomial to the results of a Runge-Kutta step.

  Args:
    y0: state at the start of the step.
    y1: state at the end of the step.
    k: list of derivative evaluations from the step; the first and last entries
      are used as the derivatives at the start and end of the interval.
    dt: step size; cast to the dtype of `y0` before use.
    tableau: _ButcherTableau supplying the `c_mid` weights.

  Returns:
    The coefficient list produced by `_interp_fit`.
  """
  with ops.name_scope('interp_fit_rk'):
    step = math_ops.cast(dt, y0.dtype)
    midpoint = y0 + _scaled_dot_product(step, tableau.c_mid, k)
    return _interp_fit(y0, y1, midpoint, k[0], k[-1], step)
def _interp_evaluate(coefficients, t0, t1, t):
  """Evaluate polynomial interpolation at the given time point.

  Args:
    coefficients: list of Tensor coefficients as created by `_interp_fit`,
      ordered from the highest power to the constant term.
    t0: scalar float64 Tensor giving the start of the interval.
    t1: scalar float64 Tensor giving the end of the interval.
    t: scalar float64 Tensor giving the desired interpolation point.

  Returns:
    Polynomial interpolation of the coefficients at time `t`.
  """
  with ops.name_scope('interp_evaluate'):
    t0 = ops.convert_to_tensor(t0)
    t1 = ops.convert_to_tensor(t1)
    t = ops.convert_to_tensor(t)

    dtype = coefficients[0].dtype

    inside_interval = (t0 <= t) & (t <= t1)
    assert_op = control_flow_ops.Assert(
        inside_interval,
        ['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
    with ops.control_dependencies([assert_op]):
      # Normalized position of `t` within [t0, t1].
      x = math_ops.cast((t - t0) / (t1 - t0), dtype)

    # Build [1, x, x**2, ...] with one entry per coefficient (at least two).
    powers = [constant_op.constant(1, dtype), x]
    for _ in range(len(coefficients) - 2):
      powers.append(powers[-1] * x)

    # Coefficients start at the highest power, so pair them with the powers in
    # descending order.
    return _dot_product(coefficients, reversed(powers))
def _optimal_step_size(last_step,
                       error_ratio,
                       safety=0.9,
                       ifactor=10.0,
                       dfactor=0.2,
                       order=5,
                       name=None):
  """Calculate the optimal size for the next Runge-Kutta step.

  Args:
    last_step: Tensor giving the size of the step just attempted.
    error_ratio: ratio of estimated error to tolerance for that step; cast to
      the dtype of `last_step`.
    safety: divisor applied to the error-based factor.
    ifactor: the step is divided by at least `1 / ifactor`.
    dfactor: the step is divided by at most `1 / dfactor`.
    order: the error ratio is raised to the power `1 / order`.
    name: optional name for the resulting op.

  Returns:
    `last_step` divided by the clamped factor.
  """
  with ops.name_scope(
      name, 'optimal_step_size', [last_step, error_ratio]) as scope:
    error_ratio = math_ops.cast(error_ratio, last_step.dtype)
    exponent = math_ops.cast(1 / order, last_step.dtype)
    # Dividing by a clamped factor looks roundabout, but it keeps error_ratio
    # in the numerator, so an error ratio of zero cannot cause a division by
    # zero.
    smallest_factor = 1 / ifactor
    error_factor = error_ratio ** exponent / safety
    largest_factor = 1 / dfactor
    factor = math_ops.maximum(
        smallest_factor, math_ops.minimum(error_factor, largest_factor))
    return math_ops.div(last_step, factor, name=scope)
def _abs_square(x):
  """Return the squared magnitude of `x`, handling complex dtypes explicitly."""
  if not x.dtype.is_complex:
    return math_ops.square(x)
  real_squared = math_ops.square(math_ops.real(x))
  imag_squared = math_ops.square(math_ops.imag(x))
  return real_squared + imag_squared
def _ta_append(tensor_array, value):
"""Append a value to the end of a tf.TensorArray."""
return tensor_array.write(tensor_array.size(), value)
class _RungeKuttaState(collections.namedtuple(
    '_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
  """Saved state of the Runge Kutta solver.

  Attributes:
    y1: Tensor giving the function value at the end of the last time step.
    f1: Tensor giving derivative at the end of the last time step.
    t0: scalar float64 Tensor giving start of the last time step.
    t1: scalar float64 Tensor giving end of the last time step.
    dt: scalar float64 Tensor giving the size for the next time step.
    interp_coeff: list of Tensors giving coefficients for polynomial
      interpolation between `t0` and `t1`.
  """
class _History(collections.namedtuple(
    '_History', 'integrate_points, error_ratio')):
  """Saved integration history for use in `info_dict`.

  Both arrays receive one new entry per attempted integration step.

  Attributes:
    integrate_points: tf.TensorArray storing integrating time points.
    error_ratio: tf.TensorArray storing computed error ratios at each
      integration step.
  """
def _dopri5(func,
            y0,
            t,
            rtol,
            atol,
            full_output=False,
            first_step=None,
            safety=0.9,
            ifactor=10.0,
            dfactor=0.2,
            max_num_steps=1000,
            name=None):
  """Solve an ODE for `odeint` using method='dopri5'.

  Args:
    func: callable invoked as `func(y, t)` to compute the time derivative.
    y0: Tensor holding the initial state; stored as the first solution entry.
    t: 1-D Tensor of output times. A graph assertion checks that it is
      strictly increasing.
    rtol: relative tolerance used when forming the per-element error tolerance.
    atol: absolute tolerance used when forming the per-element error tolerance.
    full_output: if True, also return a dict describing the integration.
    first_step: size of the first attempted step (1.0 when None).
    safety: forwarded to `_optimal_step_size`.
    ifactor: forwarded to `_optimal_step_size`.
    dfactor: forwarded to `_optimal_step_size`.
    max_num_steps: maximum number of attempted steps while advancing to any
      single entry of `t`; the step counter restarts for each entry.
    name: optional name scope for the created ops.

  Returns:
    The packed solution `y` (one entry per element of `t`), or the tuple
    `(y, info_dict)` when `full_output` is True, where `info_dict` has the keys
    'num_func_evals', 'integrate_points' and 'error_ratio'.
  """

  if first_step is None:
    # at some point, we might want to switch to picking the step size
    # automatically
    first_step = 1.0

  with ops.name_scope(
      name, 'dopri5',
      [y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps]) as scope:

    # Step-control parameters share the dtype of the time points.
    first_step = ops.convert_to_tensor(first_step, dtype=t.dtype,
                                       name='first_step')
    safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
    ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
    dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
    max_num_steps = ops.convert_to_tensor(max_num_steps, dtype=dtypes.int32,
                                          name='max_num_steps')

    def adaptive_runge_kutta_step(rk_state, history, n_steps):
      """Take an adaptive Runge-Kutta step to integrate the ODE."""
      # The end of the previous step (y1, f1, t1) is the start of this one.
      y0, f0, _, t0, dt, interp_coeff = rk_state
      with ops.name_scope('assertions'):
        check_underflow = control_flow_ops.Assert(
            t0 + dt > t0, ['underflow in dt', dt])
        check_max_num_steps = control_flow_ops.Assert(
            n_steps < max_num_steps, ['max_num_steps exceeded'])
        check_numerics = control_flow_ops.Assert(
            math_ops.reduce_all(math_ops.is_finite(abs(y0))),
            ['non-finite values in state `y`', y0])
      with ops.control_dependencies(
          [check_underflow, check_max_num_steps, check_numerics]):
        y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)

      with ops.name_scope('error_ratio'):
        # We use the same approach as the dopri5 fortran code.
        error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
        tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
        # Could also use reduce_maximum here.
        error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
        accept_step = error_ratio <= 1

      with ops.name_scope('update/rk_state'):
        # If we don't accept the step, the _RungeKuttaState will be useless
        # (covering a time-interval of size 0), but that's OK, because in such
        # cases we always immediately take another Runge-Kutta step.
        y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
        f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
        t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
        interp_coeff = control_flow_ops.cond(
            accept_step,
            lambda: _interp_fit_rk(y0, y1, k, dt),
            lambda: interp_coeff)
        dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
        rk_state = _RungeKuttaState(
            y_next, f_next, t0, t_next, dt_next, interp_coeff)

      with ops.name_scope('update/history'):
        # Recorded for every attempted step, whether or not it was accepted.
        history = _History(_ta_append(history.integrate_points, t0 + dt),
                           _ta_append(history.error_ratio, error_ratio))
      return rk_state, history, n_steps + 1

    def interpolate(solution, history, rk_state, i):
      """Interpolate through the next time point, integrating as necessary."""
      with ops.name_scope('interpolate'):
        # Keep stepping until the last step reaches t[i]; the step counter
        # starts from 0 for every output time.
        rk_state, history, _ = control_flow_ops.while_loop(
            lambda rk_state, *_: t[i] > rk_state.t1,
            adaptive_runge_kutta_step,
            (rk_state, history, 0),
            name='integrate_loop')
        y = _interp_evaluate(
            rk_state.interp_coeff, rk_state.t0, rk_state.t1, t[i])
        solution = solution.write(i, y)
        return solution, history, rk_state, i + 1

    assert_increasing = control_flow_ops.Assert(
        math_ops.reduce_all(t[1:] > t[:-1]),
        ['`t` must be monotonic increasing'])
    with ops.control_dependencies([assert_increasing]):
      num_times = array_ops.size(t)

    solution = tensor_array_ops.TensorArray(
        y0.dtype, size=num_times).write(0, y0)
    history = _History(
        integrate_points=tensor_array_ops.TensorArray(
            t.dtype, size=0, dynamic_size=True),
        error_ratio=tensor_array_ops.TensorArray(
            rtol.dtype, size=0, dynamic_size=True))
    # Initial state covers the empty interval [t[0], t[0]]; the interpolation
    # coefficients are placeholders of the right structure.
    rk_state = _RungeKuttaState(
        y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)

    # Entry 0 of the solution is y0 itself, so output times start at index 1.
    solution, history, _, _ = control_flow_ops.while_loop(
        lambda _, __, ___, i: i < num_times,
        interpolate,
        (solution, history, rk_state, 1),
        name='interpolate_loop')

    y = solution.pack(name=scope)
    y.set_shape(t.get_shape().concatenate(y0.get_shape()))
    if not full_output:
      return y
    else:
      integrate_points = history.integrate_points.pack()
      # Six evaluations are counted per attempted step, plus the initial
      # func(y0, t[0]).
      info_dict = {'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
                   'integrate_points': integrate_points,
                   'error_ratio': history.error_ratio.pack()}
      return (y, info_dict)
def odeint(func,
           y0,
           t,
           rtol=1e-6,
           atol=1e-12,
           method=None,
           options=None,
           full_output=False,
           name=None):
  """Integrate a system of ordinary differential equations.

  Solves the initial value problem for a non-stiff system of first order ode-s:

    ```
    dy/dt = func(y, t), y(t[0]) = y0
    ```

  where y is a Tensor of any shape.

  For example:

    ```
    # solve `dy/dt = -y`, corresponding to exponential decay
    tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
    => [1, exp(-1), exp(-2)]
    ```

  Output dtypes and numerical precision are based on the dtypes of the inputs
  `y0` and `t`.

  Currently, implements 5th order Runge-Kutta with adaptive step size control
  and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
  method of `scipy.integrate.ode` and MATLAB's `ode45`.

  Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
  Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
  doi:10.2307/2008219

  Args:
    func: Function that maps a Tensor holding the state `y` and a scalar Tensor
      `t` into a Tensor of state derivatives with respect to time.
    y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
      have any floating point or complex dtype.
    t: 1-D Tensor holding a sequence of time points for which to solve for
      `y`. The initial time point should be the first element of this sequence,
      and each time must be larger than the previous time. May have any floating
      point dtype. If not provided as a Tensor, converted to a Tensor with
      float64 dtype.
    rtol: optional Tensor specifying an upper bound on relative error, per
      element of `y`. Converted to the dtype of `abs(y0)`.
    atol: optional Tensor specifying an upper bound on absolute error, per
      element of `y`. Converted to the dtype of `abs(y0)`.
    method: optional string indicating the integration method to use. Currently,
      the only valid option is `'dopri5'`.
    options: optional dict of configuring options for the indicated integration
      method. Can only be provided if a `method` is explicitly set. For
      `'dopri5'`, valid options include:
      * first_step: an initial guess for the size of the first integration
        step (current default: 1.0, but may later be changed to use heuristics
        based on the gradient).
      * safety: safety factor for adaptive step control, generally a constant
        in the range 0.8-1 (default: 0.9).
      * ifactor: maximum factor by which the adaptive step may be increased
        (default: 10.0).
      * dfactor: maximum factor by which the adaptive step may be decreased
        (default: 0.2).
      * max_num_steps: integer maximum number of integrate steps between time
        points in `t` (default: 1000).
    full_output: optional boolean. If True, `odeint` returns a tuple
      `(y, info_dict)` describing the integration process.
    name: Optional name for this operation.

  Returns:
    y: (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each desired time point in
      `t`, with the initial value `y0` being the first element along the first
      dimension.
    info_dict: only if `full_output == True`. A dict with the following values:
      * num_func_evals: integer Tensor counting the number of function
        evaluations.
      * integrate_points: 1D Tensor (same dtype as `t`) with the upper bound of
        each integration time step.
      * error_ratio: 1D float Tensor with the estimated ratio of the integration
        error to the error tolerance at each integration step. A ratio greater
        than 1 corresponds to rejected steps.

  Raises:
    ValueError: if an invalid `method` is provided, or if `options` is supplied
      without `method`.
    TypeError: if `t` or `y0` has an invalid dtype.
  """
  if method is not None and method != 'dopri5':
    raise ValueError('invalid method: %r' % method)

  if options is None:
    options = {}
  elif method is None:
    raise ValueError('cannot supply `options` without specifying `method`')

  with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
    # TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
    # arbitrarily nested tuple. This will help performance and usability by
    # avoiding the need to pack/unpack in user functions.
    y0 = ops.convert_to_tensor(y0, name='y0')
    if not (y0.dtype.is_floating or y0.dtype.is_complex):
      raise TypeError('`y0` must have a floating point or complex floating '
                      'point dtype')

    t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    if not t.dtype.is_floating:
      raise TypeError('`t` must have a floating point dtype')

    # Tolerances are compared against magnitudes of the state, so they take the
    # dtype of `abs(y0)` rather than that of `y0` itself.
    error_dtype = abs(y0).dtype
    rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
    atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')

    # `options` is forwarded as keyword arguments of the solver.
    return _dopri5(func, y0, t,
                   rtol=rtol,
                   atol=atol,
                   full_output=full_output,
                   name=scope,
                   **options)

View File

@ -0,0 +1,232 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ODE solvers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.integrate.python.ops import odes
class OdeIntTest(tf.test.TestCase):
def setUp(self):
  super(OdeIntTest, self).setUp()
  # Shared default problem: a linear system dy/dt = A y whose solution is a
  # sin-wave, together with its initial column vector.
  generator = tf.constant([[0, 1], [-1, 0]], dtype=tf.float64)
  self.y0 = np.array([[1.0], [0.0]])
  self.func = lambda y, t: tf.matmul(generator, y)
def test_odeint_exp(self):
  # Exponential growth: dy / dt = y with y(0) = 1.0 has the analytical
  # solution y = exp(t).
  initial_value = tf.constant(1.0, dtype=tf.float64)
  times = np.linspace(0.0, 1.0, 11)
  solution = tf.contrib.integrate.odeint(lambda y, t: y, initial_value, times)
  # The result op lives in the `odeint` name scope and has one entry per time.
  self.assertIn('odeint', solution.name)
  self.assertEqual(solution.get_shape(), tf.TensorShape([11]))
  with self.test_session() as sess:
    solved_values = sess.run(solution)
  expected_values = np.exp(times)
  self.assertAllClose(expected_values, solved_values)
def test_odeint_complex(self):
# Test a complex, linear ODE:
# dy / dt = k * y, y(0) = 1.0.
# Its analytical solution is y = exp(k * t).
k = 1j - 0.1
func = lambda y, t: k * y
t = np.linspace(0.0, 1.0, 11)
y_solved = tf.contrib.integrate.odeint(func, 1.0 + 0.0j, t)
with self.test_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(k * t)
self.assertAllClose(y_true, y_solved)
def test_odeint_riccati(self):
# The Ricatti equation is:
# dy / dt = (y - t) ** 2 + 1.0, y(0) = 0.5.
# Its analytical solution is y = 1.0 / (2.0 - t) + t.
func = lambda t, y: (y - t)**2 + 1.0
t = np.linspace(0.0, 1.0, 11)
y_solved = tf.contrib.integrate.odeint(func, np.float64(0.5), t)
with self.test_session() as sess:
y_solved = sess.run(y_solved)
y_true = 1.0 / (2.0 - t) + t
self.assertAllClose(y_true, y_solved)
def test_odeint_2d_linear(self):
# Solve the 2D linear differential equation:
# dy1 / dt = 3.0 * y1 + 4.0 * y2,
# dy2 / dt = -4.0 * y1 + 3.0 * y2,
# y1(0) = 0.0,
# y2(0) = 1.0.
# Its analytical solution is
# y1 = sin(4.0 * t) * exp(3.0 * t),
# y2 = cos(4.0 * t) * exp(3.0 * t).
matrix = tf.constant([[3.0, 4.0], [-4.0, 3.0]], dtype=tf.float64)
func = lambda y, t: tf.matmul(matrix, y)
y0 = tf.constant([[0.0], [1.0]], dtype=tf.float64)
t = np.linspace(0.0, 1.0, 11)
y_solved = tf.contrib.integrate.odeint(func, y0, t)
with self.test_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.zeros((len(t), 2, 1))
y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t)
y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t)
self.assertAllClose(y_true, y_solved, atol=1e-5)
def test_odeint_higher_rank(self):
func = lambda y, t: y
y0 = tf.constant(1.0, dtype=tf.float64)
t = np.linspace(0.0, 1.0, 11)
for shape in [(), (1,), (1, 1)]:
expected_shape = (len(t),) + shape
y_solved = tf.contrib.integrate.odeint(func, tf.reshape(y0, shape), t)
self.assertEqual(y_solved.get_shape(), tf.TensorShape(expected_shape))
with self.test_session() as sess:
y_solved = sess.run(y_solved)
self.assertEquals(y_solved.shape, expected_shape)
def test_odeint_all_dtypes(self):
func = lambda y, t: y
t = np.linspace(0.0, 1.0, 11)
for y0_dtype in [tf.float32, tf.float64, tf.complex64, tf.complex128]:
for t_dtype in [tf.float32, tf.float64]:
y0 = tf.cast(1.0, y0_dtype)
y_solved = tf.contrib.integrate.odeint(func, y0, tf.cast(t, t_dtype))
with self.test_session() as sess:
y_solved = sess.run(y_solved)
expected = np.asarray(np.exp(t))
self.assertAllClose(y_solved, expected, rtol=1e-5)
self.assertEqual(tf.as_dtype(y_solved.dtype), y0_dtype)
def test_odeint_required_dtypes(self):
with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'):
tf.contrib.integrate.odeint(self.func, tf.cast(self.y0, tf.int32), [0, 1])
with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'):
tf.contrib.integrate.odeint(self.func, self.y0, tf.cast([0, 1], tf.int32))
def test_odeint_runtime_errors(self):
with self.assertRaisesRegexp(
ValueError, 'cannot supply `options` without'):
tf.contrib.integrate.odeint(self.func, self.y0, [0, 1],
options={'first_step': 1.0})
y = tf.contrib.integrate.odeint(self.func, self.y0, [0, 1], method='dopri5',
options={'max_num_steps': 0})
with self.test_session() as sess:
with self.assertRaisesRegexp(
tf.errors.InvalidArgumentError, 'max_num_steps'):
sess.run(y)
y = tf.contrib.integrate.odeint(self.func, self.y0, [1, 0])
with self.test_session() as sess:
with self.assertRaisesRegexp(
tf.errors.InvalidArgumentError, 'monotonic increasing'):
sess.run(y)
def test_odeint_different_times(self):
# integrate steps should be independent of interpolation times
times0 = np.linspace(0, 10, num=11, dtype=float)
times1 = np.linspace(0, 10, num=101, dtype=float)
with self.test_session() as sess:
y_solved_0, info_0 = sess.run(
tf.contrib.integrate.odeint(
self.func, self.y0, times0, full_output=True))
y_solved_1, info_1 = sess.run(
tf.contrib.integrate.odeint(
self.func, self.y0, times1, full_output=True))
self.assertAllClose(y_solved_0, y_solved_1[::10])
self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals'])
self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points'])
self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio'])
def test_odeint_5th_order_accuracy(self):
t = [0, 20]
kwargs = dict(full_output=True,
method='dopri5',
options=dict(max_num_steps=2000))
with self.test_session() as sess:
_, info_0 = sess.run(tf.contrib.integrate.odeint(
self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
_, info_1 = sess.run(tf.contrib.integrate.odeint(
self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs))
self.assertAllClose(info_0['integrate_points'].size * 1000 ** 0.2,
float(info_1['integrate_points'].size),
rtol=0.01)
class StepSizeTest(tf.test.TestCase):
  """Tests for the private `odes._optimal_step_size` helper."""

  def _evaluate_step(self, error_ratio):
    """Runs `_optimal_step_size` for a unit last step and returns its value."""
    step_op = odes._optimal_step_size(last_step=tf.constant(1.0),
                                      error_ratio=tf.constant(error_ratio))
    with self.test_session() as sess:
      return sess.run(step_op)

  def test_error_ratio_one(self):
    # An error ratio of exactly one yields a step of 0.9.
    step_value = self._evaluate_step(1.0)
    self.assertAllClose(step_value, 0.9)

  def test_ifactor(self):
    # A zero error ratio yields a step of 10.0.
    step_value = self._evaluate_step(0.0)
    self.assertAllClose(step_value, 10.0)

  def test_dfactor(self):
    # A very large error ratio yields a step of 0.2.
    step_value = self._evaluate_step(1e6)
    self.assertAllClose(step_value, 0.2)
class InterpolationTest(tf.test.TestCase):
  """Tests for the private `odes._interp_fit` / `odes._interp_evaluate` pair."""

  def test_5th_order_polynomial(self):
    # The fitted interpolant should reproduce this polynomial exactly.
    def poly(x):
      return x ** 4 + x ** 3 - 2 * x ** 2 + 4 * x + 5

    def poly_prime(x):
      return 4 * x ** 3 + 3 * x ** 2 - 4 * x + 4

    # Fit from the values at both ends and the midpoint of [0, 10], plus the
    # derivatives at both ends.
    coeffs = odes._interp_fit(
        poly(0.0), poly(10.0), poly(5.0),
        poly_prime(0.0), poly_prime(10.0), 10.0)

    sample_times = np.linspace(0, 10, dtype=np.float32)
    fitted = [odes._interp_evaluate(coeffs, 0.0, 10.0, t)
              for t in sample_times]
    y_fit = tf.pack(fitted)
    y_expected = poly(sample_times)
    with self.test_session() as sess:
      y_actual = sess.run(y_fit)
    self.assertAllClose(y_expected, y_actual)

    # Evaluating outside the fitted interval must fail at run time.
    out_of_bounds = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
    with self.test_session() as sess:
      with self.assertRaises(tf.errors.InvalidArgumentError):
        sess.run(out_of_bounds)
if __name__ == '__main__':
tf.test.main()

View File

@ -258,10 +258,11 @@ def optimize_loss(loss,
grad_values = gradient
if grad_values is not None:
var_name = variable.name.replace(":", "_")
if "gradients" in summaries:
summary.histogram("gradients/" + variable.name, grad_values)
summary.histogram("gradients/%s" % var_name, grad_values)
if "gradient_norm" in summaries:
summary.scalar("gradient_norm/" + variable.name,
summary.scalar("gradient_norm/%s" % var_name,
clip_ops.global_norm([grad_values]))
if clip_gradients is not None and "gradient_norm" in summaries:

View File

@ -291,7 +291,9 @@ py_test(
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:test_ops",
],
)

View File

@ -29,11 +29,11 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
from tensorflow.contrib.learn.python.learn.estimators.head import MetricKey
from tensorflow.contrib.learn.python.learn.estimators.head import PredictionKey
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig

View File

@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import composable_model
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
@ -747,13 +748,16 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
Numpy array of predicted classes (or an iterable of predicted classes if
as_iterable is True).
"""
preds = self._estimator.predict(x=x, input_fn=input_fn,
batch_size=batch_size,
outputs=[head_lib.PredictionKey.CLASSES],
as_iterable=as_iterable)
key = prediction_key.PredictionKey.CLASSES
preds = self._estimator.predict(
x=x,
input_fn=input_fn,
batch_size=batch_size,
outputs=[key],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
return preds[head_lib.PredictionKey.CLASSES].reshape(-1)
return _as_iterable(preds, output=key)
return preds[key].reshape(-1)
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@ -775,20 +779,22 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
Numpy array of predicted probabilities (or an iterable of predicted
probabilities if as_iterable is True).
"""
key = prediction_key.PredictionKey.PROBABILITIES
preds = self._estimator.predict(
x=x, input_fn=input_fn,
x=x,
input_fn=input_fn,
batch_size=batch_size,
outputs=[head_lib.PredictionKey.PROBABILITIES],
outputs=[key],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
return preds[head_lib.PredictionKey.PROBABILITIES]
return _as_iterable(preds, output=key)
return preds[key]
def _get_predict_ops(self, features):
"""See `Estimator` class."""
# pylint: disable=protected-access
return self._estimator._get_predict_ops(features)[
head_lib.PredictionKey.PROBABILITIES]
prediction_key.PredictionKey.PROBABILITIES]
def get_variable_names(self):
"""Returns list of all variable names in this model.
@ -826,9 +832,9 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
input_fn=input_fn or default_input_fn,
input_feature_key=input_feature_key,
use_deprecated_input_fn=use_deprecated_input_fn,
signature_fn=(
signature_fn or export.classification_signature_fn_with_prob),
prediction_key=head_lib.PredictionKey.PROBABILITIES,
signature_fn=(signature_fn or
export.classification_signature_fn_with_prob),
prediction_key=prediction_key.PredictionKey.PROBABILITIES,
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
@ -1041,10 +1047,11 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
head=head,
config=config,
feature_engineering_fn=feature_engineering_fn,
default_prediction_key=head_lib.PredictionKey.SCORES,
default_prediction_key=prediction_key.PredictionKey.SCORES,
enable_centered_bias=enable_centered_bias)
def _get_predict_ops(self, features):
"""See base class."""
return super(DNNLinearCombinedRegressor, self)._get_predict_ops(features)[
head_lib.PredictionKey.SCORES]
return super(
DNNLinearCombinedRegressor,
self)._get_predict_ops(features)[prediction_key.PredictionKey.SCORES]

View File

@ -45,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
@ -1108,8 +1109,9 @@ class Estimator(BaseEstimator):
result = _make_metrics_ops(all_metrics, features, labels,
model_fn_ops.predictions)
if 'loss' not in result:
result['loss'] = metrics_lib.streaming_mean(model_fn_ops.loss)
if metric_key.MetricKey.LOSS not in result:
result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean(
model_fn_ops.loss)
return result
def _get_predict_ops(self, features):

View File

@ -24,6 +24,8 @@ from tensorflow.contrib import losses
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.session_bundle import exporter
from tensorflow.python import summary
from tensorflow.python.framework import ops
@ -388,17 +390,17 @@ class _RegressionHead(_Head):
def _logits_to_prediction(self, logits=None):
predictions = {}
if self.logits_dimension == 1:
predictions[PredictionKey.SCORES] = array_ops.squeeze(
predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
logits, squeeze_dims=[1])
else:
predictions[PredictionKey.SCORES] = logits
predictions[prediction_key.PredictionKey.SCORES] = logits
return predictions
# pylint: disable=undefined-variable
def _create_signature_fn(self):
def _regression_signature_fn(examples, unused_features, predictions):
if isinstance(predictions, dict):
score = predictions[PredictionKey.SCORES]
score = predictions[prediction_key.PredictionKey.SCORES]
else:
score = predictions
@ -409,11 +411,12 @@ class _RegressionHead(_Head):
return _regression_signature_fn
def _default_metric(self):
return {_head_prefixed(self._head_name, MetricKey.LOSS):
_weighted_average_loss_metric_spec(self._eval_loss_fn,
PredictionKey.SCORES,
self._label_name,
self._weight_column_name)}
return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._eval_loss_fn,
prediction_key.PredictionKey.SCORES,
self._label_name,
self._weight_column_name)}
class _MultiClassHead(_Head):
@ -530,12 +533,16 @@ class _MultiClassHead(_Head):
return self._logits_to_prediction(logits)
def _logits_to_prediction(self, logits=None):
predictions = {PredictionKey.LOGITS: logits}
# pylint: disable=missing-docstring
predictions = {prediction_key.PredictionKey.LOGITS: logits}
if self.logits_dimension == 1:
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
logits)
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
predictions[PredictionKey.PROBABILITIES] = nn.softmax(logits)
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
logits)
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
logits, 1)
return predictions
@ -546,8 +553,9 @@ class _MultiClassHead(_Head):
if isinstance(predictions, dict):
default_signature = exporter.classification_signature(
input_tensor=examples,
classes_tensor=predictions[PredictionKey.CLASSES],
scores_tensor=predictions[PredictionKey.PROBABILITIES])
classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
scores_tensor=predictions[
prediction_key.PredictionKey.PROBABILITIES])
else:
default_signature = exporter.classification_signature(
input_tensor=examples,
@ -558,44 +566,49 @@ class _MultiClassHead(_Head):
return _classification_signature_fn
def _default_metric(self):
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
_weighted_average_loss_metric_spec(self._eval_loss_fn,
PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._eval_loss_fn,
prediction_key.PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
PredictionKey.CLASSES, self._label_name,
prediction_key.PredictionKey.CLASSES,
self._label_name,
self._weight_column_name))
if self.logits_dimension == 1:
def _add_binary_metric(metric_key, metric_fn):
metrics[_head_prefixed(self._head_name, metric_key)] = (
def _add_binary_metric(key, metric_fn):
metrics[_head_prefixed(self._head_name, key)] = (
metric_spec.MetricSpec(metric_fn,
PredictionKey.LOGISTIC,
prediction_key.PredictionKey.LOGISTIC,
self._label_name,
self._weight_column_name))
_add_binary_metric(MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
_add_binary_metric(MetricKey.LABEL_MEAN, _labels_streaming_mean)
_add_binary_metric(
metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
_add_binary_metric(
metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
# Also include the streaming mean of the label as an accuracy baseline, as
# a reminder to users.
_add_binary_metric(MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
_add_binary_metric(
metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
_add_binary_metric(MetricKey.AUC, _streaming_auc)
_add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
for threshold in self._thresholds:
_add_binary_metric(MetricKey.ACCURACY_MEAN % threshold,
_add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
_accuracy_at_threshold(threshold))
# Precision for positive examples.
_add_binary_metric(MetricKey.PRECISION_MEAN % threshold,
_add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
_streaming_at_threshold(
metrics_lib.streaming_precision_at_thresholds,
threshold),)
# Recall for positive examples.
_add_binary_metric(MetricKey.RECALL_MEAN % threshold,
_add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
_streaming_at_threshold(
metrics_lib.streaming_recall_at_thresholds,
threshold))
@ -635,21 +648,24 @@ class _BinarySvmHead(_MultiClassHead):
def _logits_to_prediction(self, logits=None):
predictions = {}
predictions[PredictionKey.LOGITS] = logits
predictions[prediction_key.PredictionKey.LOGITS] = logits
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
logits, 1)
return predictions
def _default_metric(self):
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
_weighted_average_loss_metric_spec(self._eval_loss_fn,
PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._eval_loss_fn,
prediction_key.PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
PredictionKey.CLASSES, self._label_name,
prediction_key.PredictionKey.CLASSES,
self._label_name,
self._weight_column_name))
# TODO(sibyl-vie3Poto): add more metrics relevant for svms.
return metrics
@ -674,12 +690,14 @@ class _MultiLabelHead(_MultiClassHead):
thresholds=thresholds)
def _logits_to_prediction(self, logits=None):
predictions = {PredictionKey.LOGITS: logits}
predictions = {prediction_key.PredictionKey.LOGITS: logits}
if self.logits_dimension == 1:
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
logits)
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
predictions[PredictionKey.PROBABILITIES] = math_ops.sigmoid(logits)
predictions[PredictionKey.CLASSES] = math_ops.to_int64(
predictions[prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid(
logits)
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64(
math_ops.greater(logits, 0))
return predictions
@ -849,23 +867,3 @@ def _streaming_at_threshold(streaming_metrics_fn, threshold):
return array_ops.squeeze(precision_tensor), update_op
return _streaming_metrics
class PredictionKey(object):
CLASSES = "classes"
PROBABILITIES = "probabilities"
LOGITS = "logits"
LOGISTIC = "logistic"
SCORES = "scores"
class MetricKey(object):
LOSS = "loss"
AUC = "auc"
PREDICTION_MEAN = "labels/prediction_mean"
LABEL_MEAN = "labels/actual_label_mean"
ACCURACY = "accuracy"
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
RECALL_MEAN = "recall/positive_threshold_%f_mean"

View File

@ -32,6 +32,7 @@ from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
from tensorflow.python.framework import dtypes
@ -267,21 +268,18 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
Example:
```python
education = sparse_column_with_hash_bucket(column_name="education",
hash_bucket_size=1000)
occupation = sparse_column_with_hash_bucket(column_name="occupation",
hash_bucket_size=1000)
sparse_column_a = sparse_column_with_hash_bucket(...)
sparse_column_b = sparse_column_with_hash_bucket(...)
education_x_occupation = crossed_column(columns=[education, occupation],
hash_bucket_size=10000)
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
# Estimator using the default optimizer.
estimator = LinearClassifier(
feature_columns=[occupation, education_x_occupation])
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
# Or estimator using the FTRL optimizer with regularization.
estimator = LinearClassifier(
feature_columns=[occupation, education_x_occupation],
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
optimizer=tf.train.FtrlOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001
@ -289,7 +287,7 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
# Or estimator using the SDCAOptimizer.
estimator = LinearClassifier(
feature_columns=[occupation, education_x_occupation],
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
optimizer=tf.contrib.linear_optimizer.SDCAOptimizer(
example_id_column='example_id',
num_loss_partitions=...,
@ -465,13 +463,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
as_iterable=False)
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
"""Runs inference to determine the predicted class."""
preds = self._estimator.predict(x=x, input_fn=input_fn,
batch_size=batch_size,
outputs=[head_lib.PredictionKey.CLASSES],
as_iterable=as_iterable)
key = prediction_key.PredictionKey.CLASSES
preds = self._estimator.predict(
x=x,
input_fn=input_fn,
batch_size=batch_size,
outputs=[key],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
return preds[head_lib.PredictionKey.CLASSES]
return _as_iterable(preds, output=key)
return preds[key]
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@ -479,14 +480,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
as_iterable=True):
"""Runs inference to determine the class probability predictions."""
preds = self._estimator.predict(x=x, input_fn=input_fn,
batch_size=batch_size,
outputs=[
head_lib.PredictionKey.PROBABILITIES],
as_iterable=as_iterable)
key = prediction_key.PredictionKey.PROBABILITIES
preds = self._estimator.predict(
x=x,
input_fn=input_fn,
batch_size=batch_size,
outputs=[key],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
return preds[head_lib.PredictionKey.PROBABILITIES]
return _as_iterable(preds, output=key)
return preds[key]
def get_variable_names(self):
return self._estimator.get_variable_names()
@ -512,9 +515,9 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
input_fn=input_fn or default_input_fn,
input_feature_key=input_feature_key,
use_deprecated_input_fn=use_deprecated_input_fn,
signature_fn=(
signature_fn or export.classification_signature_fn_with_prob),
prediction_key=head_lib.PredictionKey.PROBABILITIES,
signature_fn=(signature_fn or
export.classification_signature_fn_with_prob),
prediction_key=prediction_key.PredictionKey.PROBABILITIES,
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
@ -561,16 +564,13 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
Example:
```python
education = sparse_column_with_hash_bucket(column_name="education",
hash_bucket_size=1000)
occupation = sparse_column_with_hash_bucket(column_name="occupation",
hash_bucket_size=1000)
sparse_column_a = sparse_column_with_hash_bucket(...)
sparse_column_b = sparse_column_with_hash_bucket(...)
education_x_occupation = crossed_column(columns=[education, occupation],
hash_bucket_size=10000)
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
estimator = LinearRegressor(
feature_columns=[occupation, education_x_occupation])
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
# Input builders
def input_fn_train: # returns x, y
@ -731,13 +731,16 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
as_iterable=False)
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
"""Runs inference to determine the predicted class."""
preds = self._estimator.predict(x=x, input_fn=input_fn,
batch_size=batch_size,
outputs=[head_lib.PredictionKey.SCORES],
as_iterable=as_iterable)
key = prediction_key.PredictionKey.SCORES
preds = self._estimator.predict(
x=x,
input_fn=input_fn,
batch_size=batch_size,
outputs=[key],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=head_lib.PredictionKey.SCORES)
return preds[head_lib.PredictionKey.SCORES]
return _as_iterable(preds, output=key)
return preds[key]
def get_variable_names(self):
return self._estimator.get_variable_names()
@ -764,7 +767,7 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
input_feature_key=input_feature_key,
use_deprecated_input_fn=use_deprecated_input_fn,
signature_fn=(signature_fn or export.regression_signature_fn),
prediction_key=head_lib.PredictionKey.SCORES,
prediction_key=prediction_key.PredictionKey.SCORES,
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)

View File

@ -0,0 +1,30 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Enum for metric keys."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class MetricKey(object):
  """String constants naming evaluation metrics.

  The entries containing `%f` (`ACCURACY_MEAN`, `PRECISION_MEAN`,
  `RECALL_MEAN`) are `%`-format templates rather than final names. Format
  them with a float threshold, e.g. `MetricKey.ACCURACY_MEAN % 0.5` gives
  `"accuracy/threshold_0.500000_mean"`.
  """
  # Plain metric names.
  LOSS = "loss"
  AUC = "auc"
  PREDICTION_MEAN = "labels/prediction_mean"
  LABEL_MEAN = "labels/actual_label_mean"
  ACCURACY = "accuracy"
  ACCURACY_BASELINE = "accuracy/baseline_label_mean"
  # Per-threshold templates; apply `% threshold` to get the metric name.
  ACCURACY_MEAN = "accuracy/threshold_%f_mean"
  PRECISION_MEAN = "precision/positive_threshold_%f_mean"
  RECALL_MEAN = "recall/positive_threshold_%f_mean"

View File

@ -0,0 +1,26 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Enum for model prediction keys."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class PredictionKey(object):
  """String constants used as keys into a model's predictions dictionary."""
  CLASSES = "classes"
  PROBABILITIES = "probabilities"
  LOGITS = "logits"
  LOGISTIC = "logistic"
  SCORES = "scores"

View File

@ -30,6 +30,7 @@ from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import linear
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
@ -188,13 +189,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
as_iterable=False)
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
"""Runs inference to determine the predicted class."""
preds = self._estimator.predict(x=x, input_fn=input_fn,
batch_size=batch_size,
outputs=[head_lib.PredictionKey.CLASSES],
as_iterable=as_iterable)
key = prediction_key.PredictionKey.CLASSES
preds = self._estimator.predict(
x=x,
input_fn=input_fn,
batch_size=batch_size,
outputs=[key],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
return preds[head_lib.PredictionKey.CLASSES]
return _as_iterable(preds, output=key)
return preds[key]
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@ -202,14 +206,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
as_iterable=True):
"""Runs inference to determine the class probability predictions."""
preds = self._estimator.predict(x=x, input_fn=input_fn,
batch_size=batch_size,
outputs=[
head_lib.PredictionKey.PROBABILITIES],
as_iterable=as_iterable)
key = prediction_key.PredictionKey.PROBABILITIES
preds = self._estimator.predict(
x=x,
input_fn=input_fn,
batch_size=batch_size,
outputs=[key],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
return preds[head_lib.PredictionKey.PROBABILITIES]
return _as_iterable(preds, output=key)
return preds[key]
# pylint: enable=protected-access
def get_variable_names(self):

View File

@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
@ -77,7 +78,8 @@ def get_summary_writer(logdir):
def _make_saver(graph, keep_checkpoint_max=5):
vars_to_save = graph.get_collection(ops.GraphKeys.VARIABLES)
vars_to_save = (graph.get_collection(ops.GraphKeys.VARIABLES) +
graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
if vars_to_save:
return tf_saver.Saver(vars_to_save,
sharded=True,
@ -846,9 +848,11 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)
graph = contrib_ops.get_graph_from_inputs(output_dict.values())
with graph.as_default() as g:
with tf_session.Session('') as session:
session.run(
resources.initialize_resources(resources.shared_resources() +
resources.local_resources()))
if restore_checkpoint_path:
_restore_from_checkpoint(session, g, restore_checkpoint_path)
else:

View File

@ -28,6 +28,8 @@ from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
@ -194,6 +196,19 @@ class GraphActionsTest(tf.test.TestCase):
pass
self.assertTrue(request_stop.called)
def test_run_feeds_iter_calls_resources_init(self):
  # Registers a stub resource in a fresh graph and checks that it reports as
  # initialized while `run_feeds_iter` is iterating.
  with tf.Graph().as_default():
    in0, _, _ = self._build_inference_graph()
    handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
    create_op = test_ops.resource_create_op(handle)
    is_initialized_op = test_ops.resource_initialized_op(handle)
    resources.register_resource(
        handle=handle,
        create_op=create_op,
        is_initialized_op=is_initialized_op)

    feeds_iter = learn.graph_actions.run_feeds_iter(
        {'in0': in0}, feed_dicts=[{}])
    for _ in feeds_iter:
      self.assertTrue(test_ops.resource_initialized_op(handle).eval())
def test_infer_different_default_graph(self):
with self.test_session():
self._assert_ckpt(self._output_dir, False)

View File

@ -1,7 +1,8 @@
### TensorFlow Makefile
The recommended way to build TensorFlow from source is using the Bazel
open-source build system. Sometimes this isn't possible.
open-source build system. Sometimes this isn't possible. For example,
if you are building for iOS, you currently need to use the Makefile.
- The build system may not have the RAM or processing power to support Bazel.
- Bazel or its dependencies may not be available.

View File

@ -43,6 +43,13 @@ tensorflow/core/kernels/sequence_ops.cc
tensorflow/core/kernels/sendrecv_ops.cc
tensorflow/core/kernels/scatter_op.cc
tensorflow/core/kernels/scatter_functor.cc
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
tensorflow/core/kernels/scatter_nd_op.cc
tensorflow/core/kernels/save_restore_tensor.cc
tensorflow/core/kernels/save_restore_v2_ops.cc
tensorflow/core/kernels/save_op.cc

View File

@ -763,7 +763,12 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
computes the area under a discretized curve of precision versus recall values
(computed using the aforementioned variables). The `num_thresholds` variable
controls the degree of discretization with larger numbers of thresholds more
closely approximating the true AUC.
closely approximating the true AUC. The quality of the approximation may vary
dramatically depending on `num_thresholds`.
For best results, `predictions` should be distributed approximately uniformly
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
approximation may be poor if this is not the case.
For estimation of the metric over a stream of data, the function creates an
`update_op` operation that updates these variables and returns the `auc`.

View File

@ -15,9 +15,16 @@ py_library(
"python/training/resample.py",
"python/training/sampling_ops.py",
"python/training/sequence_queueing_state_saver.py",
"python/training/training.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python:framework",
"//tensorflow/python:ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
],
)
py_test(
@ -98,6 +105,19 @@ py_test(
],
)
# Tests for python/training/training.py. Declared as a large test and split
# across three shards.
py_test(
    name = "training_test",
    size = "large",
    srcs = ["python/training/training_test.py"],
    shard_count = 3,
    srcs_version = "PY2AND3",
    deps = [
        ":training_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
    ],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -70,6 +70,11 @@ from tensorflow.contrib.training.python.training.bucket_ops import *
from tensorflow.contrib.training.python.training.resample import *
from tensorflow.contrib.training.python.training.sampling_ops import *
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
from tensorflow.contrib.training.python.training.training import add_gradients_summaries
from tensorflow.contrib.training.python.training.training import clip_gradient_norms
from tensorflow.contrib.training.python.training.training import create_train_op
from tensorflow.contrib.training.python.training.training import multiply_gradients
from tensorflow.contrib.training.python.training.training import train
from tensorflow.python.util.all_util import make_all
__all__ = make_all(__name__)

View File

@ -0,0 +1,316 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains various routines and helper functions for training models.
TODO(nsilberman): Port documentation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.python import summary
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import optimizer as tf_optimizer
# TODO(nsilberman): move add_gradients_summaries, clip_gradient_norms and
# multiply_gradients into contrib/summaries and contrib/optimizers.py
__all__ = [
'add_gradients_summaries',
'clip_gradient_norms',
'create_train_op',
'multiply_gradients',
'train',
]
def add_gradients_summaries(grads_and_vars):
  """Creates histogram summaries for each gradient that is not `None`.

  For every pair with a gradient, two summaries are created: one over the
  gradient values (the `.values` of an `IndexedSlices`) and one over their
  global norm. Pairs whose gradient is `None` are only logged.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).

  Returns:
    The list of created summaries.
  """
  created = []
  for grad, var in grads_and_vars:
    if grad is None:
      logging.info('Var %s has no gradient', var.op.name)
      continue

    # Sparse gradients are summarized through their dense `values` tensor.
    values = grad.values if isinstance(grad, ops.IndexedSlices) else grad

    values_summary = summary.histogram_summary(
        var.op.name + ':gradient', values)
    created.append(values_summary)

    norm_summary = summary.histogram_summary(
        var.op.name + ':gradient_norm', clip_ops.global_norm([values]))
    created.append(norm_summary)

  return created
def clip_gradient_norms(gradients_to_variables, max_norm):
  """Clips each gradient to `max_norm` using `clip_ops.clip_by_norm`.

  `IndexedSlices` gradients have their `values` clipped and are rebuilt with
  the original indices and dense shape. `None` gradients pass through
  unchanged.

  Args:
    gradients_to_variables: A list of gradient to variable pairs (tuples).
    max_norm: the maximum norm value.

  Returns:
    A list of clipped gradient to variable pairs.
  """
  result = []
  for gradient, variable in gradients_to_variables:
    if isinstance(gradient, ops.IndexedSlices):
      clipped_values = clip_ops.clip_by_norm(gradient.values, max_norm)
      gradient = ops.IndexedSlices(
          clipped_values, gradient.indices, gradient.dense_shape)
    elif gradient is not None:
      gradient = clip_ops.clip_by_norm(gradient, max_norm)
    result.append((gradient, variable))
  return result
def multiply_gradients(grads_and_vars, gradient_multipliers):
  """Scales the gradients of selected variables by constant factors.

  A variable is selected when either the variable itself or its op name is a
  key of `gradient_multipliers`; the variable key takes precedence. Pairs for
  variables that are not selected are returned unchanged.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    gradient_multipliers: A map from either `Variables` or `Variable` op names
      to the coefficient by which the associated gradient should be scaled.

  Returns:
    The updated list of gradient to variable pairs.

  Raises:
    ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
      is empty or None or if `gradient_multipliers` is not a dictionary, or if
      a selected variable has a `None` gradient.
  """
  if not isinstance(grads_and_vars, list):
    raise ValueError('`grads_and_vars` must be a list.')
  if not gradient_multipliers:
    raise ValueError('`gradient_multipliers` is empty.')
  if not isinstance(gradient_multipliers, dict):
    raise ValueError('`gradient_multipliers` must be a dict.')

  scaled = []
  for grad, var in grads_and_vars:
    # Resolve the lookup key: the variable object first, then its op name.
    if var in gradient_multipliers:
      key = var
    elif var.op.name in gradient_multipliers:
      key = var.op.name
    else:
      scaled.append((grad, var))
      continue

    if grad is None:
      raise ValueError('Requested multiple of `None` gradient.')

    multiplier = constant_op.constant(
        gradient_multipliers[key], dtype=grad.dtype)
    if isinstance(grad, ops.IndexedSlices):
      scaled_values = grad.values * multiplier
      grad = ops.IndexedSlices(scaled_values, grad.indices, grad.dense_shape)
    else:
      grad *= multiplier
    scaled.append((grad, var))
  return scaled
def create_train_op(total_loss,
                    optimizer,
                    global_step=None,
                    update_ops=None,
                    variables_to_train=None,
                    transform_grads_fn=None,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False):
  """Creates an `Operation` that evaluates the gradients and returns the loss.

  Args:
    total_loss: A `Tensor` representing the total loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `None`, then `variables.get_or_create_global_step()` is used.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
      a warning will be displayed.
    variables_to_train: an optional list of variables to train. If None, it will
      default to all tf.trainable_variables().
    transform_grads_fn: A function which takes a single argument, a list of
      gradient to variable pairs (tuples), performs any requested gradient
      updates, such as gradient clipping or multipliers, and returns the updated
      list.
    summarize_gradients: Whether or not add summaries for each gradient.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: Whether or not to try colocating the gradients
      with the ops that generated them.

  Returns:
    A `Tensor` that when evaluated, computes the gradients and returns the total
    loss value.

  Raises:
    AssertionError: If an entry of `variables_to_train` is not among the
      trainable variables, or if there are no variables to train.
  """
  if global_step is None:
    global_step = variables.get_or_create_global_step()

  # Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
  global_update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
  if update_ops is None:
    update_ops = global_update_ops
  else:
    update_ops = set(update_ops)
  if not global_update_ops.issubset(update_ops):
    logging.warning('update_ops in create_train_op does not contain all the '
                    'update_ops in GraphKeys.UPDATE_OPS')

  # Make sure update_ops are computed before total_loss.
  if update_ops:
    with ops.control_dependencies(update_ops):
      barrier = control_flow_ops.no_op(name='update_barrier')
    total_loss = control_flow_ops.with_dependencies([barrier], total_loss)

  if variables_to_train is None:
    # Default to tf.trainable_variables()
    variables_to_train = tf_variables.trainable_variables()
  else:
    # Make sure that variables_to_train are in tf.trainable_variables().
    # The collection is fetched once rather than once per variable.
    trainable = tf_variables.trainable_variables()
    for v in variables_to_train:
      assert v in trainable

  assert variables_to_train

  # Create the gradients. Note that compute_gradients adds the gradient
  # computation to the current graph.
  grads = optimizer.compute_gradients(
      total_loss,
      variables_to_train,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)

  if transform_grads_fn:
    grads = transform_grads_fn(grads)

  # Summarize gradients.
  if summarize_gradients:
    with ops.name_scope('summarize_grads'):
      add_gradients_summaries(grads)

  # Create gradient updates.
  grad_updates = optimizer.apply_gradients(grads, global_step=global_step)

  with ops.name_scope('train_op'):
    # Make sure total_loss is valid.
    total_loss = array_ops.check_numerics(total_loss,
                                          'LossTensor is inf or nan')

    # Ensure the train_tensor computes grad_updates.
    return control_flow_ops.with_dependencies([grad_updates], total_loss)
def train(
    train_op,
    logdir,
    master='',
    is_chief=True,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=100,
    config=None):
  """Runs the training loop.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where the graph and checkpoints are saved.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: An tf.train.Scaffold instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop. The list passed in is not modified.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    the value of the loss function after training, or `None` if the session
    requested a stop before the first step ran.

  Raises:
    ValueError: if `is_chief` is true, `logdir` is `None` and either
      `save_checkpoint_secs` or `save_summaries_steps` is set.
  """
  # TODO(nsilberman): move this logic into monitored_session.py
  scaffold = scaffold or monitored_session.Scaffold()

  # Work on a copy so that the hooks added below never leak into the list the
  # caller passed in.
  hooks = list(hooks) if hooks else []

  if is_chief:
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=logdir,
        master=master,
        config=config)

    if chief_only_hooks:
      hooks.extend(chief_only_hooks)

    hooks.append(basic_session_run_hooks.StepCounterHook(
        output_dir=logdir))

    if save_summaries_steps:
      if logdir is None:
        raise ValueError(
            'logdir cannot be None when save_summaries_steps is not None')
      hooks.append(basic_session_run_hooks.SummarySaverHook(
          scaffold=scaffold,
          save_steps=save_summaries_steps,
          output_dir=logdir))

    if save_checkpoint_secs:
      if logdir is None:
        raise ValueError(
            'logdir cannot be None when save_checkpoint_secs is not None')
      hooks.append(basic_session_run_hooks.CheckpointSaverHook(
          logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
  else:
    session_creator = monitored_session.WorkerSessionCreator(
        scaffold=scaffold, master=master, config=config)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    loss = None
    # Keep stepping until a hook (or the session itself) requests a stop.
    while not session.should_stop():
      loss = session.run(train_op)
  return loss

View File

@ -0,0 +1,514 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.training.training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
def logistic_classifier(inputs):
  # A single fully connected unit with a sigmoid activation.
  fully_connected = tf.contrib.layers.fully_connected
  return fully_connected(inputs, 1, activation_fn=tf.sigmoid)
def batchnorm_classifier(inputs):
  # Batch-normalizes the inputs (decay 0.1) and feeds them to a single fully
  # connected unit with a sigmoid activation.
  normalized = tf.contrib.layers.batch_norm(inputs, decay=0.1)
  return tf.contrib.layers.fully_connected(
      normalized, 1, activation_fn=tf.sigmoid)
class CreateTrainOpTest(tf.test.TestCase):
  """Checks how `create_train_op` treats the batch-norm moving-average updates."""

  def setUp(self):
    np.random.seed(0)

    # Create an easy training set:
    self._inputs = np.random.rand(16, 4).astype(np.float32)
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

  def testUseUpdateOps(self):
    # With the default `update_ops`, running the train op must also move the
    # batch-norm moving statistics towards the statistics of the inputs.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      # Per-feature statistics of the (constant) input batch.
      expected_mean = np.mean(self._inputs, axis=(0))
      expected_var = np.var(self._inputs, axis=(0))

      tf_predictions = batchnorm_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)

      moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
      moving_variance = tf.contrib.framework.get_variables_by_name(
          'moving_variance')[0]

      with tf.Session() as sess:
        # Initialize all variables
        sess.run(tf.initialize_all_variables())
        mean, variance = sess.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          sess.run([train_op])
        mean = moving_mean.eval()
        variance = moving_variance.eval()
        # After 10 updates with decay 0.1 moving_mean == expected_mean and
        # moving_variance == expected_var.
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(variance, expected_var)

  def testEmptyUpdateOps(self):
    # Passing `update_ops=[]` must leave the batch-norm moving statistics at
    # their initial values even after several training steps.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer, update_ops=[])

      moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
      moving_variance = tf.contrib.framework.get_variables_by_name(
          'moving_variance')[0]

      with tf.Session() as sess:
        # Initialize all variables
        sess.run(tf.initialize_all_variables())
        mean, variance = sess.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          sess.run([train_op])
        mean = moving_mean.eval()
        variance = moving_variance.eval()

        # Since we skip update_ops the moving_vars are not updated.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)
class TrainBNClassifierTest(tf.test.TestCase):
  """Trains the batch-norm classifier end to end through `train`."""

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
    self._logdir = os.path.join(self.get_temp_dir(), 'tmp_bnlogs/')

    # Each row is one-hot: label 0 sets column 0 or 1, label 1 sets column 2
    # or 3, so the label is fully determined by the input.
    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
    # 300 steps of gradient descent should bring the loss below 0.1.
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, self._logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      self.assertLess(loss, .1)
class TrainTest(tf.test.TestCase):
  """End-to-end tests of `train` and `create_train_op` on a logistic model."""

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    # Each row is one-hot: label 0 sets column 0 or 1, label 1 sets column 2
    # or 3, so the label is fully determined by the input.
    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testCanAchieveZeroLoss(self):
    # 300 steps of gradient descent should bring the loss below 0.015.
    logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')

    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testTrainWithLocalVariable(self):
    # Same as testCanAchieveZeroLoss, but the predictions are multiplied by a
    # local variable, so training must also work when one is in the graph.
    logdir = os.path.join(self.get_temp_dir(), 'train_with_local_variable')

    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      local_multiplier = tf.contrib.framework.local_variable(1.0)

      tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testResumeTrainAchievesRoughlyTheSameLoss(self):
    # Runs three training sessions against the same logdir (300, 1 and 5
    # steps) and expects a low loss after each of them.
    number_of_steps = [300, 1, 5]
    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

    for i in range(len(number_of_steps)):
      with tf.Graph().as_default():
        tf.set_random_seed(i)
        tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
        tf_labels = tf.constant(self._labels, dtype=tf.float32)

        tf_predictions = logistic_classifier(tf_inputs)
        tf.contrib.losses.log_loss(tf_predictions, tf_labels)
        total_loss = tf.contrib.losses.get_total_loss()

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

        train_op = tf.contrib.training.create_train_op(
            total_loss, optimizer)

        saver = tf.train.Saver()

        loss = tf.contrib.training.train(
            train_op, logdir, hooks=[
                tf.train.StopAtStepHook(num_steps=number_of_steps[i]),
                tf.train.CheckpointSaverHook(
                    logdir, save_steps=50, saver=saver),
            ])
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)

  def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
    # Helper: builds the logistic model in the current default graph and
    # returns its train op. When `gradient_multiplier` differs from 1.0, every
    # trainable variable's gradient is scaled by it via `multiply_gradients`.
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
    total_loss = tf.contrib.losses.get_total_loss()

    optimizer = tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate)

    def transform_grads_fn(grads):
      if gradient_multiplier != 1.0:
        variables = tf.trainable_variables()
        gradient_multipliers = {var: gradient_multiplier for var in variables}

        with tf.name_scope('multiply_grads'):
          return tf.contrib.training.multiply_gradients(
              grads, gradient_multipliers)
      else:
        return grads

    return tf.contrib.training.create_train_op(
        total_loss, optimizer, transform_grads_fn=transform_grads_fn)

  def testTrainWithInitFromCheckpoint(self):
    # Trains in three stages; the last stage starts in a new logdir and
    # restores its variables from the checkpoint written by the second stage.
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if tf.gfile.Exists(logdir1):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir1)
    if tf.gfile.Exists(logdir2):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op()
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=1),
          ], save_checkpoint_secs=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with tf.Graph().as_default():
      tf.set_random_seed(1)
      train_op = self.create_train_op()
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=300),
          ], save_checkpoint_secs=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with tf.Graph().as_default():
      tf.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = tf.all_variables()
      model_path = os.path.join(logdir1, 'model.ckpt-300')

      assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
          model_path, model_variables)
      # The scaffold calls init_fn with two arguments; only the session is
      # forwarded to the checkpoint-assign function.
      def init_fn(_, session):
        assign_fn(session)

      loss = tf.contrib.training.train(
          train_op,
          logdir2,
          scaffold=tf.train.Scaffold(init_fn=init_fn),
          hooks=[tf.train.StopAtStepHook(num_steps=1)])

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

  def ModelLoss(self):
    # Helper: builds the logistic model in the current default graph and
    # returns its total loss tensor.
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
    return tf.contrib.losses.get_total_loss()

  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    # Trains weights only, then biases only, then everything, all against the
    # same logdir; only the final stage is expected to drop below 0.015.
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if tf.gfile.Exists(logdir):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir)

    # First, train only the weights of the model.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      weights = tf.contrib.framework.get_variables_by_name('weights')

      train_op = tf.contrib.training.create_train_op(
          total_loss,
          optimizer,
          variables_to_train=weights)

      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=200),
          ])
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Next, train the biases of the model.
    with tf.Graph().as_default():
      tf.set_random_seed(1)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      biases = tf.contrib.framework.get_variables_by_name('biases')

      train_op = tf.contrib.training.create_train_op(
          total_loss,
          optimizer,
          variables_to_train=biases)

      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=300),
          ])
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Finally, train both weights and bias to get lower loss.
    with tf.Graph().as_default():
      tf.set_random_seed(2)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=400),
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    # Builds three train ops over the same loss (all variables, weights only,
    # biases only) and checks which variables each one changes.
    # First, train only the weights of the model.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      weights, biases = tf.contrib.framework.get_variables()

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
      train_weights = tf.contrib.training.create_train_op(
          total_loss, optimizer, variables_to_train=[weights])
      train_biases = tf.contrib.training.create_train_op(
          total_loss, optimizer, variables_to_train=[biases])

      with tf.Session() as sess:
        # Initialize the variables.
        sess.run(tf.initialize_all_variables())

        # Get the initial weights and biases values.
        weights_values, biases_values = sess.run([weights, biases])
        self.assertGreater(np.linalg.norm(weights_values), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values), 0)

        # Update weights and biases.
        loss = sess.run(train_op)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights and biases have been updated.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

        weights_values, biases_values = new_weights, new_biases

        # Update only weights.
        loss = sess.run(train_weights)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights have been updated, but biases have not.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
        weights_values = new_weights

        # Update only biases.
        loss = sess.run(train_biases)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the biases have been updated, but weights have not.
        self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

  def testTrainWithAlteredGradients(self):
    # Use the same learning rate but different gradient multipliers
    # to train two models. Model with equivalently larger learning
    # rate (i.e., learning_rate * gradient_multiplier) has smaller
    # training loss.
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')

    if tf.gfile.Exists(logdir1):
      tf.gfile.DeleteRecursively(logdir1)
    if tf.gfile.Exists(logdir2):
      tf.gfile.DeleteRecursively(logdir2)

    multipliers = [1., 1000.]
    number_of_steps = 10
    losses = []
    learning_rate = 0.001

    # First, train the model with equivalently smaller learning rate.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate,
          gradient_multiplier=multipliers[0])

      saver = tf.train.Saver()

      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.StopAtStepHook(num_steps=number_of_steps),
              tf.train.CheckpointSaverHook(logdir1, save_steps=50, saver=saver),
          ])

      losses.append(loss)
      self.assertGreater(loss, .5)

    # Second, train the model with equivalently larger learning rate.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate,
          gradient_multiplier=multipliers[1])
      saver = tf.train.Saver()

      loss = tf.contrib.training.train(
          train_op, logdir2, hooks=[
              tf.train.StopAtStepHook(num_steps=number_of_steps),
              tf.train.CheckpointSaverHook(logdir2, save_steps=50, saver=saver),
          ])

      losses.append(loss)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .5)

    # The loss of the model trained with larger learning rate should
    # be smaller.
    self.assertGreater(losses[0], losses[1])
if __name__ == '__main__':
tf.test.main()

View File

@ -221,13 +221,6 @@ cc_library(
],
)
cc_library(
name = "jpeg",
hdrs = ["lib/jpeg/jpeg_mem.h"],
visibility = ["//visibility:public"],
deps = [":jpeg_internal"],
)
# Test support library needed for all tests
# This is currently public, but may be made internal in the
# future. Try to avoid depending on it.
@ -699,9 +692,9 @@ filegroup(
"platform/cuda.h",
"platform/google/**/*",
"platform/hadoop/**/*",
"platform/jpeg.*",
"platform/png.*",
"platform/gif.*",
"platform/gif.h",
"platform/jpeg.h",
"platform/png.h",
"platform/stream_executor.*",
"platform/windows/**/*",
"user_ops/**/*.cu.cc",
@ -981,7 +974,10 @@ cc_library(
],
exclude = [
"**/*test*",
"lib/gif/**/*",
"lib/jpeg/**/*",
"platform/gif.h",
"platform/jpeg.h",
"platform/**/cuda.h",
"platform/**/stream_executor.h",
"platform/load_library.cc",
@ -998,7 +994,10 @@ cc_library(
],
exclude = [
"**/*test*",
"lib/gif/**/*",
"lib/jpeg/**/*",
"platform/gif.h",
"platform/jpeg.h",
"platform/**/cuda.h",
"platform/**/stream_executor.h",
],
@ -1016,7 +1015,6 @@ cc_library(
hdrs = tf_additional_lib_hdrs() + [
"lib/core/blocking_counter.h",
"lib/core/refcount.h",
"lib/gif/gif_io.h",
"lib/gtl/edit_distance.h",
"lib/gtl/int_type.h",
"lib/gtl/iterator_range.h",
@ -1060,18 +1058,32 @@ cc_library(
],
)
# GIF support library: compiles lib/gif/gif_io.cc (with platform/gif.h as a
# private source) and exposes lib/gif/gif_io.h. Depends on :lib and the
# build_config ":gif" target.
cc_library(
name = "gif_internal",
srcs = [
"lib/gif/gif_io.cc",
"platform/gif.h",
],
hdrs = ["lib/gif/gif_io.h"],
copts = tf_copts(),
linkopts = ["-ldl"],
deps = [
":lib",
"//tensorflow/core/platform/default/build_config:gif",
],
)
cc_library(
name = "jpeg_internal",
srcs = glob(
[
"lib/jpeg/*h",
"lib/jpeg/*.cc",
],
exclude = [
"**/*test*",
],
),
hdrs = ["lib/jpeg/jpeg_handle.h"],
srcs = [
"lib/jpeg/jpeg_handle.cc",
"lib/jpeg/jpeg_mem.cc",
"platform/jpeg.h",
],
hdrs = [
"lib/jpeg/jpeg_handle.h",
"lib/jpeg/jpeg_mem.h",
],
copts = tf_copts(),
linkopts = ["-ldl"],
deps = [
@ -1541,7 +1553,6 @@ cc_test(
srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
data = glob(["lib/jpeg/testdata/*.jpg"]),
deps = [
":jpeg",
":jpeg_internal",
":lib",
":lib_internal",

View File

@ -78,7 +78,7 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
Status DeviceFactory::AddDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) {
// CPU first.
// CPU first. A CPU device is required.
auto cpu_factory = GetFactory("CPU");
if (!cpu_factory) {
return errors::NotFound(
@ -90,18 +90,11 @@ Status DeviceFactory::AddDevices(const SessionOptions& options,
return errors::NotFound("No CPU devices are available in this process");
}
// Then GPU.
auto gpu_factory = GetFactory("GPU");
if (gpu_factory) {
TF_RETURN_IF_ERROR(
gpu_factory->CreateDevices(options, name_prefix, devices));
}
// Then the rest.
// Then the rest (including GPU).
mutex_lock l(*get_device_factory_lock());
for (auto& p : device_factories()) {
auto factory = p.second.factory.get();
if (factory != cpu_factory && factory != gpu_factory) {
if (factory != cpu_factory) {
TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
}
}

View File

@ -282,6 +282,7 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
return;
}
mu_.unlock();
SchedClosure([session, req, resp, done]() {
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
@ -290,7 +291,22 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
}
done(status);
});
}
// Handles a PartialRunSetup request for the session named by
// `req->session_handle()`.
//
// If no such session is registered, `done` is invoked immediately with an
// Aborted status. Otherwise the call to MasterSession::PartialRunSetup is
// handed to SchedClosure and its status is reported through `done`.
//
// `req` and `resp` are captured by pointer, so the caller must keep them
// alive until `done` has been invoked.
void Master::PartialRunSetup(const PartialRunSetupRequest* req,
                             PartialRunSetupResponse* resp, MyClosure done) {
  // mu_ is held only for the session-table lookup; it is released before
  // `done` or SchedClosure is called.
  mu_.lock();
  MasterSession* session = gtl::FindPtrOrNull(sessions_, req->session_handle());
  mu_.unlock();

  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  // The closure touches no Master state, so `this` is not captured.
  SchedClosure([session, req, resp, done]() {
    done(session->PartialRunSetup(req, resp));
  });
}
void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
@ -303,6 +319,7 @@ void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
return;
}
mu_.unlock();
SchedClosure([this, start_time, session, opts, req, resp, done]() {
Status status = session->Run(opts, req, resp);
@ -312,7 +329,6 @@ void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
last_1000_steps_.AddValue((done_time - start_time) / 1e9);
++step_count_;
});
mu_.unlock();
}
void Master::CloseSession(const CloseSessionRequest* req,

View File

@ -46,6 +46,9 @@ class Master {
void ExtendSession(const ExtendSessionRequest* req,
ExtendSessionResponse* resp, MyClosure done);
void PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp, MyClosure done);
void RunStep(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp, MyClosure done);

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/master.pb.h"
@ -37,6 +38,12 @@ class MasterInterface {
const ExtendSessionRequest* request,
ExtendSessionResponse* response) = 0;
// Sets up a partial run described by `request`.
//
// Unlike the neighbouring pure-virtual methods, this one has a default body:
// it ignores its arguments and returns an Unimplemented error, so only
// implementations that support partial runs need to override it.
virtual Status PartialRunSetup(CallOptions* call_options,
const PartialRunSetupRequest* request,
PartialRunSetupResponse* response) {
return errors::Unimplemented("Partial run not implemented for this master");
}
virtual Status RunStep(CallOptions* call_options,
const RunStepRequest* request,
RunStepResponse* response) = 0;

View File

@ -50,18 +50,6 @@ limitations under the License.
namespace tensorflow {
// A little bit of per-step state.
struct PerStepState {
bool collect_costs = false;
bool collect_timeline = false;
bool collect_rpcs = false;
Microseconds start_micros = Microseconds(0);
Microseconds end_micros = Microseconds(0);
std::vector<StepStats> step_stats; // per partition
StepStats rpc_stats; // for RPC layer
CostGraphDef cost_graph;
};
// MasterSession wraps SimpleClientGraph in a reference counted object.
// This way, MasterSession can clear up the cache mapping Run requests to
// compiled graphs while the compiled graph is still being used.
@ -72,15 +60,38 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
std::unique_ptr<SimpleClientGraph> cg,
const SessionOptions& session_opts,
StatsPublisherFactory stats_publisher_factory)
StatsPublisherFactory stats_publisher_factory,
SimpleGraphExecutionState* execution_state, bool is_partial)
: session_handle_(handle),
client_graph_(std::move(cg)),
bopts_(bopts),
session_opts_(session_opts) {
session_opts_(session_opts),
is_partial_(is_partial) {
VLOG(1) << "Created ReffedClientGraph for node with "
<< client_graph_->graph.num_node_ids();
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
// If this is a partial run we need to initialize a name to node map for
// testing that fetches are reachable.
if (is_partial) {
std::unordered_set<StringPiece, StringPiece::Hasher> names;
for (const string& input : bopts.feed_endpoints) {
TensorId id(ParseTensorName(input));
names.emplace(id.first);
}
for (const string& output : bopts.fetch_endpoints) {
TensorId id(ParseTensorName(output));
names.emplace(id.first);
}
// We use the graph from the execution_state because we want the graph
// nodes as they were before being rewritten or replaced by the rewriter.
for (Node* n : execution_state->full_graph()->nodes()) {
if (names.count(n->name()) > 0) {
name_to_node_.insert({n->name(), n});
}
}
}
}
~ReffedClientGraph() override { DeregisterPartitions(); }
@ -171,7 +182,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
SimpleGraphExecutionState* execution_state,
PerStepState* pss, CallOptions* opts,
const RunStepRequest& req, RunStepResponse* resp,
CancellationManager* cm);
CancellationManager* cm, const bool is_last_partial_run);
// Calls workers to cleanup states for the step "step_id". Calls
// `done` when all cleanup RPCs have completed.
@ -185,6 +196,9 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
void ProcessDeviceStats(ProfileHandler* ph,
const SimpleGraphExecutionState* execution_state,
const DeviceStepStats& ds, bool is_rpc);
// Checks that the requested fetches can be computed from the provided feeds.
Status CheckFetches(const RunStepRequest& req, const RunState* run_state,
SimpleGraphExecutionState* execution_state);
string DetailText(const NodeDef& def, const NodeExecStats& ns) {
int64 tot = 0;
@ -209,6 +223,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
std::unordered_set<const Node*> nodes_needing_input_mapping_;
BuildGraphOptions bopts_;
const SessionOptions session_opts_;
const bool is_partial_;
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> name_to_node_;
// Graph partitioned into per-location subgraphs.
struct Part {
@ -483,15 +499,14 @@ class RunManyGraphs {
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};
Status MasterSession::ReffedClientGraph::RunPartitions(
const MasterEnv* env, int64 step_id, int64 execution_count,
SimpleGraphExecutionState* execution_state, PerStepState* pss,
CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp,
CancellationManager* cm) {
CancellationManager* cm, const bool is_last_partial_run) {
VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
<< execution_count;
// Builds an index for feeds provided by the client.
// Build an index for feeds provided by the client.
std::unordered_map<StringPiece, const TensorProto*, StringPiece::Hasher>
feeds(3);
@ -524,26 +539,64 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
for (int i = 0; i < num; ++i) {
const Part& part = partitions_[i];
RunManyGraphs::Call* c = calls.get(i);
if (is_partial_) {
c->req.set_is_partial(is_partial_);
c->req.set_is_last_partial_run(is_last_partial_run);
}
c->req.set_graph_handle(part.graph_handle);
c->req.set_step_id(step_id);
*c->req.mutable_exec_opts() = exec_opts;
// If any feeds are provided, send the feed values together
// in the RunGraph request.
for (const auto& feed_key : part.feed_key) {
const string& feed = feed_key.first;
const string& key = feed_key.second;
const TensorProto* val = feeds[feed];
if (val == nullptr) {
return errors::InvalidArgument("No feed is provided for feed=", feed,
", key=", key);
// In the partial case, we only want to include feeds provided in the req.
// In the non-partial case, all feeds in the request are in the part.
// We keep these as separate paths for now, to ensure we aren't
// inadvertently slowing down the normal run path.
if (is_partial_) {
for (const auto& feed : req.feed()) {
const string& name = feed.name();
auto iter = part.feed_key.find(name);
if (iter == part.feed_key.end()) {
// The provided feed must be for a different partition.
continue;
}
const string& key = iter->second;
const TensorProto* val = feeds[name];
if (val == nullptr) {
return errors::InvalidArgument("No feed is provided for feed=", name,
", key=", key);
}
auto* send = c->req.add_send();
send->set_key(key);
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
}
// TODO(suharshs): Make a map from feed to fetch_key to make this faster.
// For now, we just iterate through partitions to find the matching key.
for (const auto& req_fetch : req.fetch()) {
for (const auto& key_fetch : part.key_fetch) {
if (key_fetch.second == req_fetch) {
c->req.add_recv_key(key_fetch.first);
break;
}
}
}
} else {
for (const auto& feed_key : part.feed_key) {
const string& feed = feed_key.first;
const string& key = feed_key.second;
const TensorProto* val = feeds[feed];
if (val == nullptr) {
return errors::InvalidArgument("No feed is provided for feed=", feed,
", key=", key);
}
auto* send = c->req.add_send();
send->set_key(key);
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
}
for (const auto& key_fetch : part.key_fetch) {
const string& key = key_fetch.first;
c->req.add_recv_key(key);
}
auto* send = c->req.add_send();
send->set_key(key);
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
}
for (const auto& key_fetch : part.key_fetch) {
const string& key = key_fetch.first;
c->req.add_recv_key(key);
}
}
@ -762,6 +815,64 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
}
}
// TODO(suharshs): Merge with CheckFetches in DirectSession.
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// every time.
//
// Checks that every fetch in `req` can be computed from the feeds supplied so
// far in this partial run.
//
// Returns NotFound if a pending feed or a requested fetch names a node that is
// absent from name_to_node_, and InvalidArgument if walking backwards from the
// fetch nodes reaches a tensor that is still waiting to be fed.
Status MasterSession::ReffedClientGraph::CheckFetches(
    const RunStepRequest& req, const RunState* run_state,
    SimpleGraphExecutionState* execution_state) {
  // Feeds that are still pending once the feeds of this request are removed.
  std::unordered_set<TensorId, TensorId::Hasher> unfed;
  for (const string& pending_name : run_state->pending_inputs) {
    const TensorId pending_id(ParseTensorName(pending_name));
    if (name_to_node_.find(pending_id.first) == name_to_node_.end()) {
      return errors::NotFound("Feed ", pending_name, ": not found");
    }
    unfed.insert(pending_id);
  }
  for (const auto& fed : req.feed()) {
    unfed.erase(TensorId(ParseTensorName(fed.name())));
  }

  // Seed the traversal with the node behind each requested fetch.
  std::vector<const Node*> to_visit;
  for (const string& fetch_name : req.fetch()) {
    const TensorId fetch_id(ParseTensorName(fetch_name));
    const auto node_it = name_to_node_.find(fetch_id.first);
    if (node_it == name_to_node_.end()) {
      return errors::NotFound("Fetch ", fetch_name, ": not found");
    }
    to_visit.push_back(node_it->second);
  }

  // Walk the input edges of the original full graph held by the execution
  // state; none of the tensors reached may still be an unfed feed.
  const Graph* full_graph = execution_state->full_graph();
  std::vector<bool> seen(full_graph->num_node_ids(), false);
  while (!to_visit.empty()) {
    const Node* current = to_visit.back();
    to_visit.pop_back();

    for (const Edge* edge : current->in_edges()) {
      const Node* producer = edge->src();
      if (unfed.count({producer->name(), edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", producer->name(), ":",
                                       edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (seen[producer->id()]) {
        continue;
      }
      seen[producer->id()] = true;
      to_visit.push_back(producer);
    }
  }
  return Status::OK();
}
// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
@ -803,6 +914,23 @@ void BuildBuildGraphOptions(const RunStepRequest& req,
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}
// Fills `opts` from a PartialRunSetupRequest: the request's feeds, fetches and
// targets are appended to the matching endpoint lists, and each list is then
// sorted.
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
                            BuildGraphOptions* opts) {
  auto& feeds = opts->feed_endpoints;
  auto& fetches = opts->fetch_endpoints;
  auto& targets = opts->target_nodes;

  feeds.insert(feeds.end(), req.feed().begin(), req.feed().end());
  fetches.insert(fetches.end(), req.fetch().begin(), req.fetch().end());
  targets.insert(targets.end(), req.target().begin(), req.target().end());

  std::sort(feeds.begin(), feeds.end());
  std::sort(targets.begin(), targets.end());
  std::sort(fetches.begin(), fetches.end());
}
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
uint64 h = 0x2b992ddfa23249d6ull;
for (const string& name : opts.feed_endpoints) {
@ -927,11 +1055,9 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
return Status::OK();
}
Status MasterSession::StartStep(const RunStepRequest& req,
BuildGraphOptions* opts, int64* count,
ReffedClientGraph** rcg) {
BuildBuildGraphOptions(req, opts);
const uint64 hash = HashBuildGraphOptions(*opts);
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** rcg, bool is_partial) {
const uint64 hash = HashBuildGraphOptions(opts);
ReffedClientGraph* to_unref = nullptr;
{
mutex_lock l(mu_);
@ -944,12 +1070,12 @@ Status MasterSession::StartStep(const RunStepRequest& req,
// We have not seen this subgraph before. Build the subgraph and
// cache it.
VLOG(1) << "Unseen hash " << hash << " for "
<< BuildGraphOptionsString(*opts);
<< BuildGraphOptionsString(opts);
std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph));
auto entry =
new ReffedClientGraph(handle_, *opts, std::move(client_graph),
session_opts_, stats_publisher_factory_);
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial);
iter = runs_.insert({hash, entry}).first;
auto obs_iter = obsolete_.find(hash);
if (obs_iter != obsolete_.end()) {
@ -979,6 +1105,47 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
rcg_map->clear();
}
// Prepares a partial run for the feeds, fetches and targets listed in `req`.
//
// Builds (or reuses) the client graph for the request via StartStep, creates
// the RunState that tracks which feeds and fetches are still pending,
// registers the graph partitions, and on success stores the RunState under a
// newly generated handle that is returned in `resp`.
//
// Returns the first error from StartStep or BuildAndRegisterPartitions. In
// the latter case the RunState is destroyed and nothing is recorded in
// partial_runs_, so no entry is left behind for a handle the client never
// received.
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
                                      PartialRunSetupResponse* resp) {
  const std::vector<string> inputs(req->feed().begin(), req->feed().end());
  const std::vector<string> outputs(req->fetch().begin(), req->fetch().end());

  string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));

  ReffedClientGraph* rcg = nullptr;
  int64 count = 0;

  // Prepare.
  BuildGraphOptions opts;
  BuildBuildGraphOptions(*req, &opts);
  TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true));
  // Keeps the highest 8 bits 0x01: we reserve some bits of the
  // step_id for future use.
  uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
  TRACEPRINTF("stepid %llu", step_id);

  // The RunState releases this reference in its destructor.
  rcg->Ref();
  std::unique_ptr<RunState> run_state(
      new RunState(inputs, outputs, rcg, step_id, count));

  // Register before publishing the run state: if this fails, `run_state` goes
  // out of scope here and is cleaned up.
  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  {
    mutex_lock l(mu_);
    partial_runs_.emplace(handle, std::move(run_state));
  }

  resp->set_partial_run_handle(handle);
  return Status::OK();
}
Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp) {
UpdateLastAccessTime();
@ -986,7 +1153,12 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
mutex_lock l(mu_);
++num_running_;
}
Status status = DoRunWithLocalExecution(opts, req, resp);
Status status;
if (!req->partial_run_handle().empty()) {
status = DoPartialRun(opts, req, resp);
} else {
status = DoRunWithLocalExecution(opts, req, resp);
}
{
mutex_lock l(mu_);
--num_running_;
@ -997,23 +1169,7 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
return status;
}
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
const RunStepRequest* req,
RunStepResponse* resp) {
VLOG(2) << "DoRunWithLocalExecution "
<< "req: " << req->DebugString();
PerStepState pss;
pss.start_micros = Env::Default()->NowMicros();
// Prepare.
BuildGraphOptions bgopts;
ReffedClientGraph* rcg = nullptr;
int64 count = 0;
TF_RETURN_IF_ERROR(StartStep(*req, &bgopts, &count, &rcg));
// Unref "rcg" when out of scope.
core::ScopedUnref unref(rcg);
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
// Registers subgraphs if haven't done so.
PartitionOptions popts;
popts.node_to_loc = SplitByWorker;
@ -1051,12 +1207,136 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(
env_, popts, rcg->client_graph()->flib_def->ToProto()));
return Status::OK();
}
// Executes one RunStep request that belongs to a partial run previously
// created by PartialRunSetup (identified by req->partial_run_handle()).
//
// Returns InvalidArgument if the handle is unknown, or if any feed or fetch in
// the request is not in the run's set of pending inputs/outputs; otherwise
// returns the status of CheckFetches or RunPartitions. When RunPartitions
// fails, or when no feeds and fetches remain pending, stats are processed and
// the partitions are cleaned up asynchronously; the cleanup callback removes
// the run from partial_runs_.
Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp) {
const string& prun_handle = req->partial_run_handle();
RunState* run_state = nullptr;
{
mutex_lock l(mu_);
auto it = partial_runs_.find(prun_handle);
if (it == partial_runs_.end()) {
return errors::InvalidArgument(
"Must run PartialRunSetup before performing partial runs");
}
// Only the map lookup happens under mu_; the RunState itself is used after
// the lock is released.
run_state = it->second.get();
}
// If this is the first partial run, initialize the PerStepState.
if (!run_state->step_started) {
run_state->step_started = true;
// NOTE(review): pss.start_micros is not assigned anywhere in this function,
// although end_micros is set further down - confirm whether the step start
// time should be recorded here.
PerStepState pss;
auto count = run_state->count;
pss.collect_timeline =
req->options().trace_level() == RunOptions::FULL_TRACE;
// Build the cost model every 'build_cost_model_every' steps after skipping
// an initial 'build_cost_model_after' steps.
const int64 build_cost_model_after =
session_opts_.config.graph_options().build_cost_model_after();
const int64 build_cost_model_every =
session_opts_.config.graph_options().build_cost_model();
pss.collect_costs =
build_cost_model_every > 0 &&
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
run_state->step_id, count, req->options());
if (ph) {
// A profile handler forces timeline collection regardless of trace level.
pss.collect_timeline = true;
pss.collect_rpcs = ph->should_collect_rpcs();
}
run_state->pss = std::move(pss);
run_state->ph = std::move(ph);
}
// Make sure that this is a new set of feeds that are still pending.
for (const auto& feed : req->feed()) {
auto it = run_state->pending_inputs.find(feed.name());
if (it == run_state->pending_inputs.end()) {
return errors::InvalidArgument("The feed ", feed.name(),
" had already been fed.");
}
}
// Check that this is a new set of fetches that are still pending.
for (const auto& fetch : req->fetch()) {
auto it = run_state->pending_outputs.find(fetch);
if (it == run_state->pending_outputs.end()) {
return errors::InvalidArgument("The fetch ", fetch,
" had already been fetched.");
}
}
// Ensure that the requested fetches can be computed from the provided feeds.
TF_RETURN_IF_ERROR(
run_state->rcg->CheckFetches(*req, run_state, execution_state_.get()));
// Determine if this partial run satisfies all the pending inputs and outputs.
// The pending sets are updated before RunPartitions is called.
for (const auto& feed : req->feed()) {
run_state->pending_inputs.erase(feed.name());
}
for (const auto& fetch : req->fetch()) {
run_state->pending_outputs.erase(fetch);
}
bool is_last_partial_run =
(run_state->pending_inputs.empty() && run_state->pending_outputs.empty());
Status s = run_state->rcg->RunPartitions(
env_, run_state->step_id, run_state->count, execution_state_.get(),
&run_state->pss, opts, *req, resp, cancellation_manager_,
is_last_partial_run);
// Delete the run state if there is an error or all fetches are done.
if (!s.ok() || is_last_partial_run) {
ReffedClientGraph* rcg = run_state->rcg;
run_state->pss.end_micros = Env::Default()->NowMicros();
// Schedule post-processing and cleanup to be done asynchronously.
// The extra reference taken here is released by the cleanup callback.
rcg->Ref();
rcg->ProcessStats(env_, run_state->step_id, &run_state->pss,
execution_state_.get(), run_state->ph.get(), *req, resp);
rcg->CleanupPartitionsAsync(
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
mutex_lock l(mu_);
partial_runs_.erase(prun_handle);
});
}
return s;
}
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
const RunStepRequest* req,
RunStepResponse* resp) {
VLOG(2) << "DoRunWithLocalExecution "
<< "req: " << req->DebugString();
PerStepState pss;
pss.start_micros = Env::Default()->NowMicros();
// Prepare.
BuildGraphOptions bgopts;
BuildBuildGraphOptions(*req, &bgopts);
ReffedClientGraph* rcg = nullptr;
int64 count = 0;
TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false));
// Unref "rcg" when out of scope.
core::ScopedUnref unref(rcg);
TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph;
pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
// Build the cost model every 'build_cost_model_every' steps after skipping an
@ -1069,15 +1349,16 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
build_cost_model_every > 0 &&
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
ph = rcg->GetProfileHandler(step_id, count, req->options());
std::unique_ptr<ProfileHandler> ph =
rcg->GetProfileHandler(step_id, count, req->options());
if (ph) {
pss.collect_timeline = true;
pss.collect_rpcs = ph->should_collect_rpcs();
}
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
execution_state_.get(), &pss, opts,
*req, resp, cancellation_manager_));
TF_RETURN_IF_ERROR(
rcg->RunPartitions(env_, step_id, count, execution_state_.get(), &pss,
opts, *req, resp, cancellation_manager_, false));
pss.end_micros = Env::Default()->NowMicros();
@ -1110,4 +1391,22 @@ Status MasterSession::Close() {
return Status::OK();
}
// Creates the bookkeeping for one partial run: stores the client graph, step
// id and execution count, and marks every name in `input_names` and
// `output_names` as still pending.
MasterSession::RunState::RunState(const std::vector<string>& input_names,
                                  const std::vector<string>& output_names,
                                  ReffedClientGraph* rcg, const uint64 step_id,
                                  const int64 count)
    : rcg(rcg), step_id(step_id), count(count) {
  // Nothing has been fed or fetched yet.
  for (const string& feed_name : input_names) {
    pending_inputs.insert(feed_name);
  }
  for (const string& fetch_name : output_names) {
    pending_outputs.insert(fetch_name);
  }
}
// Drops the reference this RunState holds on its client graph, if any.
MasterSession::RunState::~RunState() {
  if (rcg != nullptr) {
    rcg->Unref();
  }
}
} // end namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
#include <atomic>
#include <vector>
#include "tensorflow/core/common_runtime/device_set.h"
@ -72,6 +73,10 @@ class MasterSession {
// Extend() may block the caller thread for a long time.
Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);
// Setup a partial run call.
Status PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp);
// Run one step.
Status Run(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp);
@ -101,6 +106,8 @@ class MasterSession {
std::atomic_ulong last_access_time_usec_;
std::atomic<int64> partial_run_handle_counter_ = {0};
mutex mu_;
std::unique_ptr<SimpleGraphExecutionState> execution_state_;
int64 graph_version_;
@ -115,6 +122,36 @@ class MasterSession {
RCGMap runs_ GUARDED_BY(mu_);
RCGMap obsolete_ GUARDED_BY(mu_);
// A little bit of per-step state: which statistics to collect for a step,
// the step's start/end timestamps, and the statistics gathered so far.
struct PerStepState {
// Flags selecting what is collected for this step.
bool collect_costs = false;
bool collect_timeline = false;
bool collect_rpcs = false;
// Timestamps of the step; both default to zero until assigned.
Microseconds start_micros = Microseconds(0);
Microseconds end_micros = Microseconds(0);
std::vector<StepStats> step_stats; // per partition
StepStats rpc_stats; // for RPC layer
CostGraphDef cost_graph;
};
// State kept for one partial run between PartialRunSetup and the RunStep
// calls that feed and fetch its tensors.
struct RunState {
// Names of the feeds/fetches declared at setup that have not yet been
// supplied/returned.
std::unordered_set<string> pending_inputs;
std::unordered_set<string> pending_outputs;
// Client graph used by this run; the destructor calls Unref() on it when it
// is non-null.
ReffedClientGraph* rcg = nullptr;
uint64 step_id;
int64 count = 0;
PerStepState pss;
std::unique_ptr<ProfileHandler> ph;
// Set once the first partial run has initialized `pss` and `ph`.
bool step_started = false;
// Marks all `input_names` and `output_names` as pending and stores the
// remaining arguments in the corresponding members.
RunState(const std::vector<string>& input_names,
const std::vector<string>& output_names, ReffedClientGraph* rcg,
const uint64 step_id, const int64 count);
~RunState();
};
std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
GUARDED_BY(mu_);
// Active RunStep calls.
condition_variable num_running_is_zero_;
int32 num_running_ GUARDED_BY(mu_) = 0;
@ -131,14 +168,18 @@ class MasterSession {
// Private dtor. The client must call Close().
virtual ~MasterSession();
Status StartStep(const RunStepRequest& req, BuildGraphOptions* opts,
int64* count, ReffedClientGraph** graph);
Status StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** graph, bool is_partial);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp);
Status DoPartialRun(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp);
void UpdateLastAccessTime();
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
};

View File

@ -131,7 +131,7 @@ class OpKernel {
// We allow legacy scalars within Google up until GraphDef version 6.
// TODO(irving): Remove when we can drop support for GraphDef version 5.
bool allow_legacy_scalars() const {
#if defined(PLATFORM_GOOGLE)
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
return graph_def_version_ < 6;
#else
return false;

View File

@ -1136,8 +1136,9 @@ tf_kernel_libraries(
":eigen_helpers",
":image_resizer_state",
"//tensorflow/core:framework",
"//tensorflow/core:gif_internal",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:jpeg",
"//tensorflow/core:jpeg_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@ -2099,11 +2100,13 @@ tf_kernel_libraries(
"count_up_to_op",
"dense_update_ops",
"scatter_op",
"scatter_nd_op",
"variable_ops",
],
deps = [
":assign_op",
":bounds_check",
":fill_functor",
":scatter_functor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -2117,6 +2120,7 @@ tf_cc_test(
size = "small",
srcs = ["scatter_op_test.cc"],
deps = [
":fill_functor",
":ops_testutil",
":ops_util",
":scatter_op",
@ -2129,6 +2133,23 @@ tf_cc_test(
],
)
# Unit test for the scatter_nd_op kernel, built from scatter_nd_op_test.cc.
tf_cc_test(
name = "scatter_nd_op_test",
size = "small",
srcs = ["scatter_nd_op_test.cc"],
deps = [
":ops_testutil",
":ops_util",
":scatter_nd_op",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_kernel_libraries(
name = "string",
prefixes = [
@ -2571,6 +2592,7 @@ filegroup(
"debug_ops.*",
# Ops excluded because they do not build correctly for Android.
# See b/29213790
"scatter_nd_op*",
"sparse_matmul_op.*",
],
),

View File

@ -64,7 +64,7 @@ class TestFileSystem : public NullFileSystem {
std::unique_ptr<ReadOnlyMemoryRegion>* result) override {
float val = 0;
StringPiece scheme, host, path;
ParseURI(fname, &scheme, &host, &path);
io::ParseURI(fname, &scheme, &host, &path);
// For the tests create in-memory regions with float values equal to the
// region name.
if (path == "/2") {

View File

@ -46,25 +46,6 @@ namespace functor {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
// Sample a truncated normal random variable, with mean, stddev, minval, and
// maxval parameters for each batch. Uses two rejection sampling algorithms
// described in http://rd.springer.com/article/10.1007/BF00143942.
//
// Either minval may be -infinity, or maxval may be +infinity. If the interval
// (minval, maxval) is empty, the result is NaN. Large intervals which include
// both tails may have reduced accuracy.
template <typename Device, typename T>
struct TruncatedNormalFunctor {
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
typename TTypes<T>::ConstFlat means,
typename TTypes<T>::ConstFlat stddevs,
typename TTypes<T>::ConstFlat minvals,
typename TTypes<T>::ConstFlat maxvals,
const random::PhiloxRandom& gen,
typename TTypes<T>::Flat output);
};
template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
static const int kMaxIterations = 100;
@ -96,8 +77,8 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
// Vectorized intermediate calculations for uniform rejection sampling.
// We always generate at most 4 samples.
tensorflow::random::Array<T, 4> z;
tensorflow::random::Array<T, 4> g;
Eigen::array<T, 4> z;
Eigen::array<T, 4> g;
for (int64 b = start_batch; b < limit_batch; ++b) {
// We are passed a flat array for each of the parameter tensors.
@ -145,13 +126,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
if (diff < cutoff) {
// Sample from a uniform distribution on [normMin, normMax].
T plusFactor;
if (normMin < T(0)) {
// normMax > 0 because it is flipped otherwise.
plusFactor = T(0);
} else {
plusFactor = normMin * normMin;
}
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
while (sample < limit_sample) {
const auto rand = dist(&gen_copy);
@ -395,4 +370,21 @@ TF_CALL_double(REGISTER);
#undef REGISTER
#if GOOGLE_CUDA
// Registers the GPU kernel of ParameterizedTruncatedNormal for one dtype,
// with the "shape" argument marked as HostMemory. Only compiled when
// GOOGLE_CUDA is set; instantiated below for half, float and double.
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
.Device(DEVICE_GPU) \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER
#endif // GOOGLE_CUDA
} // end namespace tensorflow

View File

@ -16,14 +16,35 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/random_distributions.h"
namespace tensorflow {
class OpKernelContext;
namespace functor {
// Sample a truncated normal random variable, with mean, stddev, minval, and
// maxval parameters for each batch. Uses two rejection sampling algorithms
// described in http://rd.springer.com/article/10.1007/BF00143942.
//
// Either minval may be -infinity, or maxval may be +infinity. If the interval
// (minval, maxval) is empty, the result is NaN. Large intervals which include
// both tails may have reduced accuracy.
template <typename Device, typename T>
struct TruncatedNormalFunctor;
struct TruncatedNormalFunctor {
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
typename TTypes<T>::ConstFlat means,
typename TTypes<T>::ConstFlat stddevs,
typename TTypes<T>::ConstFlat minvals,
typename TTypes<T>::ConstFlat maxvals,
const random::PhiloxRandom& gen,
typename TTypes<T>::Flat output);
static const int kMaxIterations = 100;
};
} // namespace functor
} // namespace tensorflow

View File

@ -0,0 +1,214 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"
#include <assert.h>
#include <stdio.h>
#include <cmath>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#define UNROLL _Pragma("unroll")
namespace tensorflow {
class OpKernelContext;
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
// CUDA kernel that writes one truncated-normal sample per element of `data`.
//
// Arguments:
//   gen: Philox generator, passed by value so every thread advances its own
//     copy.
//   data: output buffer, written at `offset` for offset in [0, num_elements).
//   samples_per_batch: maps an output offset to its batch id
//     (offset / samples_per_batch).
//   means / stddevs / minvals / maxvals: per-batch parameters. When the
//     matching single_* flag is true, element 0 is used for every batch.
//   kMaxIterations: cap on the rejection-sampling attempts per element.
// `num_batches` is not read inside the kernel body.
//
// Elements whose parameters fail the validation below are set to NaN.
template <typename T>
__global__ void __launch_bounds__(1024)
TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
const T* means, bool single_mean, const T* stddevs,
bool single_stddev, const T* minvals,
bool single_minval, const T* maxvals,
bool single_maxval, int64 kMaxIterations) {
// NOTE(review): the products below are computed in int32; confirm they cannot
// overflow for the launch configurations in use.
const int32 max_samples_per_item = 2 * kMaxIterations;
// Initial offset as given by CUDA_1D_KERNEL_LOOP.
const int32 initial_offset = blockIdx.x * blockDim.x + threadIdx.x;
gen.Skip(max_samples_per_item * initial_offset);
typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
Uniform dist;
const int kDistSize = Uniform::kResultElementCount;
const T quietNaN = Eigen::NumTraits<T>::quiet_NaN();
// We skip the total number of threads to get to the next element. To produce
// deterministic results between devices, each element in the output array
// skips max_samples_per_item in the generator. Then after generating this
// item, we need to skip the samples for one element for every thread to get
// to the next element that we actually process.
const int32 samples_between_processed_elements =
max_samples_per_item * (gridDim.x * blockDim.x);
CUDA_1D_KERNEL_LOOP(offset, num_elements) {
// Track how many more samples we need to skip before we process the next
// element.
// NOTE(review): remaining_samples is decremented by gen.kResultElementCount
// for every dist() call and is finally passed to gen.Skip(). Confirm both
// use the same unit, and that it cannot become negative for every T (the
// number of dist() calls per element depends on kDistSize).
int32 remaining_samples = samples_between_processed_elements;
const int64 batch_id = offset / samples_per_batch;
T mean = means[single_mean ? 0 : batch_id];
const T input_stddev = stddevs[single_stddev ? 0 : batch_id];
T minval = minvals[single_minval ? 0 : batch_id];
T maxval = maxvals[single_maxval ? 0 : batch_id];
// Flip the distribution if we can make the lower bound positive.
T stddev;
if (Eigen::numext::isinf(minval) || maxval < mean) {
// Reverse all calculations. normMin and normMax will be flipped.
// std::swap is a host function (not available in CUDA).
T temp = minval;
minval = maxval;
maxval = temp;
stddev = -input_stddev;
} else {
stddev = input_stddev;
}
// Calculate normalized samples, then scale them.
const T normMin = (minval - mean) / stddev;
const T normMax = (maxval - mean) / stddev;
// Determine the method to use.
const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
const T cutoff =
T(2) *
Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) /
(normMin + sqrtFactor);
const T diff = normMax - normMin;
// Validate the normalized min and max, because the originals may have been
// flipped already.
// The element is NaN unless stddev is positive, the normalized interval is
// non-empty, and at least one normalized bound is finite.
if (!(input_stddev > T(0) && normMin < normMax &&
(Eigen::numext::isfinite(normMin) ||
Eigen::numext::isfinite(normMax)))) {
data[offset] = quietNaN;
} else if (diff < cutoff) {
// Sample from a uniform distribution on [normMin, normMax].
// Vectorized intermediate calculations for uniform rejection sampling.
// We always generate at most 4 samples.
Eigen::array<T, 4> z;
Eigen::array<T, 4> g;
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
int numIterations = 0;
while (numIterations < kMaxIterations) {
// First draw: candidate positions z inside [normMin, normMax].
const auto rand = dist(&gen);
remaining_samples -= gen.kResultElementCount;
UNROLL for (int i = 0; i < kDistSize; i++) {
z[i] = rand[i] * diff + normMin;
}
UNROLL for (int i = 0; i < kDistSize; i++) {
g[i] = (plusFactor - z[i] * z[i]) / 2.0;
}
// Second draw: uniform values compared against exp(g) for acceptance.
const auto u = dist(&gen);
remaining_samples -= gen.kResultElementCount;
UNROLL for (int i = 0; i < kDistSize; i++) {
if (u[i] <= Eigen::numext::exp(g[i]) ||
numIterations + 1 >= kMaxIterations) {
// Accept the sample z.
// If we run out of iterations, just use the current uniform
// sample. Empirically, the probability of accepting each sample
// is at least 50% for typical inputs, so we will always accept
// by 100 iterations.
// This introduces a slight inaccuracy when at least one bound
// is large, minval is negative and maxval is positive.
data[offset] = z[i] * stddev + mean;
// Break out of the nested loop by updating numIterations.
numIterations = kMaxIterations;
break;
} else {
numIterations++;
}
}
}
} else {
// Sample from an exponential distribution with alpha maximizing
// acceptance probability, offset by normMin from the origin.
// Accept only if less than normMax.
const T alpha =
(normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) / T(2);
int numIterations = 0;
while (numIterations < kMaxIterations) {
auto rand = dist(&gen);
remaining_samples -= gen.kResultElementCount;
// Each attempt consumes a pair of uniforms: rand[i] produces the
// candidate z and rand[i + 1] is the acceptance threshold u.
UNROLL for (int i = 0; i < kDistSize; i += 2) {
const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
const T x = normMin < alpha ? alpha - z : normMin - alpha;
const T g = Eigen::numext::exp(-x * x / 2.0);
const T u = rand[i + 1];
// As in the uniform branch, the last permitted attempt is accepted
// unconditionally.
if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) {
data[offset] = z * stddev + mean;
// Break out of the nested loop by updating numIterations.
numIterations = kMaxIterations;
break;
} else {
numIterations++;
}
}
}
}
// Advance the generator past whatever is left of this element's budget.
gen.Skip(remaining_samples);
}
}
// GPU partial specialization: launches TruncatedNormalKernel on the device's
// stream, one logical thread per output element.
template <typename T>
struct TruncatedNormalFunctor<GPUDevice, T> {
  static const int kMaxIterations = 100;

  void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
                  int64 samples_per_batch, int64 num_elements,
                  typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    // A parameter tensor holding exactly one element is shared by every batch.
    const bool one_mean = (means.dimension(0) == 1);
    const bool one_stddev = (stddevs.dimension(0) == 1);
    const bool one_minval = (minvals.dimension(0) == 1);
    const bool one_maxval = (maxvals.dimension(0) == 1);

    const auto launch = GetCudaLaunchConfig(num_elements, d);
    TruncatedNormalKernel<T>
        <<<launch.block_count, launch.thread_per_block, 0, d.stream()>>>(
            gen, output.data(), num_batches, samples_per_batch, num_elements,
            means.data(), one_mean, stddevs.data(), one_stddev,
            minvals.data(), one_minval, maxvals.data(), one_maxval,
            kMaxIterations);
  }
};
// Explicit instantiation of the GPU distributions functors
template struct TruncatedNormalFunctor<GPUDevice, Eigen::half>;
template struct TruncatedNormalFunctor<GPUDevice, float>;
template struct TruncatedNormalFunctor<GPUDevice, double>;
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -131,5 +131,8 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
BM_PTruncatedNormalDev(cpu, 1000, 1000);
BM_PTruncatedNormalDev_2SD(cpu, 10000, 100);
BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100);
BM_PTruncatedNormalDev(gpu, 1000, 1000);
BM_PTruncatedNormalDev_2SD(gpu, 10000, 100);
BM_PTruncatedNormalDev_OneTail(gpu, 10000, 100);
} // namespace tensorflow

View File

@ -0,0 +1,402 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/state_ops.cc.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
// Check whether updates.shape = indices.shape[0] + params.shape[IXDIM:]
static bool ValidUpdateShape(const TensorShape& params_shape,
const Tensor& indices, const Tensor& updates) {
int64 indices_nd = 1;
if (indices.dims() > 1) {
indices_nd = indices.dim_size(1);
}
for (int d = indices_nd; d < params_shape.dims(); d++) {
if (updates.dim_size(d - indices_nd + 1) != params_shape.dim_size(d)) {
return false;
}
}
return true;
}
// Validates the inputs of a scatter_nd style op and derives the quantities
// the kernels need to view `updates` as a 2-D [num_updates, slice_size]
// tensor.
//
// On failure the error is recorded on `c` (via OP_REQUIRES) and the function
// returns early; callers must check c->status() afterwards.
//
// Outputs:
//   *indices_nd:  innermost dimension of `indices` (1 if rank <= 1).
//   *num_updates: indices.num_elements() / max(*indices_nd, 1).
//   *slice_size:  product of params_shape dimensions from *indices_nd onward.
template <typename Index>
static void PrepareAndValidateInputs(OpKernelContext* c,
                                     const TensorShape& params_shape,
                                     const Tensor& indices,
                                     const Tensor& updates, int64* indices_nd,
                                     Index* num_updates, Index* slice_size) {
  const TensorShape& indices_shape(indices.shape());
  const TensorShape& updates_shape(updates.shape());

  OP_REQUIRES(
      c, TensorShapeUtils::IsVectorOrHigher(params_shape),
      errors::InvalidArgument("Output must be at least 1-D, ", "got shape: ",
                              params_shape.DebugString()));

  // An empty output is only acceptable when there is nothing to scatter.
  // (The comparison must be strict: num_elements() >= 0 always holds, which
  // would make this check a no-op.)
  OP_REQUIRES(c, params_shape.num_elements() > 0 ||
                     (indices.NumElements() == 0 && updates.NumElements() == 0),
              errors::InvalidArgument(
                  "Indices and updates specified for empty output", " shape"));

  // dim_size(0) is read below, so both tensors need at least one dimension.
  OP_REQUIRES(c, indices_shape.dims() >= 1 && updates_shape.dims() >= 1,
              errors::InvalidArgument(
                  "Indices and updates must be at least 1-D. Got ",
                  "indices.shape ", indices_shape.DebugString(),
                  ", updates.shape ", updates_shape.DebugString()));

  OP_REQUIRES(c, updates.dim_size(0) == indices.dim_size(0),
              errors::InvalidArgument(
                  "The outermost dimension of updates and indices ",
                  "must match. Got indices.shape ", indices_shape.DebugString(),
                  ", updates.shape ", updates_shape.DebugString()));
  OP_REQUIRES(
      c, ValidUpdateShape(params_shape, indices, updates),
      errors::InvalidArgument(
          "Must have updates.shape = indices.shape[0] + params_shape[IXDIM:], ",
          "got updates.shape ", updates_shape.DebugString(), ", indices.shape ",
          indices_shape.DebugString(), ", params_shape ",
          params_shape.DebugString()));

  // Check that we have enough index space
  const int64 N_big = indices.NumElements();
  OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
              errors::InvalidArgument(
                  "indices has too many elements for ",
                  DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
                  N_big, " > ", std::numeric_limits<Index>::max()));
  OP_REQUIRES(
      c, params_shape.dim_size(0) <= std::numeric_limits<Index>::max(),
      errors::InvalidArgument("params_shape[0] too large for ",
                              DataTypeString(DataTypeToEnum<Index>::v()),
                              " indexing: ", params_shape.dim_size(0), " > ",
                              std::numeric_limits<Index>::max()));

  // Calculate the number of dimensions in indices
  *indices_nd = 1;
  if (indices_shape.dims() > 1) {
    *indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);
  }

  // Calculate the number of elements that make up each slice of our updated
  // tensor. This allows us to work with flattened tensors and copy over whole
  // slices at a time.
  Index total_nd = params_shape.dims();

  int64 slice_size_big = 1;
  for (int64 i = *indices_nd; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
  }

  OP_REQUIRES(c, slice_size_big <= std::numeric_limits<Index>::max(),
              errors::InvalidArgument("slice size is too large for indexing: ",
                                      slice_size_big, " > ",
                                      std::numeric_limits<Index>::max()));

  *slice_size = static_cast<Index>(slice_size_big);

  // Guard the division when the innermost dimension of indices is 0.
  const int64 safe_indices_nd = (*indices_nd < 1) ? 1 : *indices_nd;
  *num_updates = indices_shape.num_elements() / safe_indices_nd;
}
template <typename Device, typename T, typename Index>
class ScatterNdOp : public OpKernel {
public:
explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
const DataType dt = DataTypeToEnum<T>::v();
const DataType index_t = DataTypeToEnum<Index>::v();
OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
}
void Compute(OpKernelContext* c) override {
const Tensor& indices = c->input(0);
const Tensor& updates = c->input(1);
const Tensor& shape_input = c->input(2);
OP_REQUIRES(c, shape_input.dims() == 1,
errors::InvalidArgument("Shape must be a vector"));
auto vec = shape_input.flat<Index>();
TensorShape shape;
TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape);
int64 indices_nd;
Index num_updates;
Index slice_size;
PrepareAndValidateInputs<Index>(c, shape, indices, updates, &indices_nd,
&num_updates, &slice_size);
if (!c->status().ok()) return;
Tensor scratch;
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
auto scratch_scalar = scratch.scalar<Index>();
auto indices_flat = indices.flat_inner_dims<Index>();
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
Index bad_i = -1;
switch (indices_nd) {
#define PARAMS_CASE(IXDIM) \
case IXDIM: { \
Tensor* out = nullptr; \
OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); \
functor::SetZeroFunctor<Device, T> fill; \
fill(c->eigen_device<Device>(), out->flat<T>()); \
if (shape.num_elements() > 0) { \
auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>(); \
functor::ScatterNdFunctor<Device, T, Index, \
scatter_nd_op::UpdateOp::ASSIGN, (IXDIM)> \
functor; \
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
output_flat, indices_flat, updates_flat, output_flat); \
} \
} break
PARAMS_CASE(0);
PARAMS_CASE(1);
PARAMS_CASE(2);
PARAMS_CASE(3);
PARAMS_CASE(4);
PARAMS_CASE(5);
#undef PARAMS_CASE
default:
OP_REQUIRES(c, false,
errors::InvalidArgument(
"Only indices.shape[-1] values between 0 and 5 "
"are currently supported. Requested rank: ",
indices_nd));
}
OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"Invalid indices: ", SliceDebugString(indices.shape(), bad_i),
" = [", str_util::Join(gtl::ArraySlice<Index>(
&indices_flat(bad_i, 0), indices_nd),
", "),
"] does not index into ", shape.DebugString()));
}
};
// Kernel that applies `op` (assign / add / sub / mul / div) to slices of a
// ref-typed params tensor, in place, at the positions given by `indices`.
//
// Inputs: 0 = params (ref T), 1 = indices (Index), 2 = updates (T).
// The ref input is forwarded to output 0. When the "use_locking" attribute is
// true the input's ref mutex is held for the whole update.
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class ScatterNdUpdateOp : public OpKernel {
 public:
  explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType dt_ref = DataTypeToEnum<T>::ref();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* c) override {
    if (use_exclusive_lock_) {
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    Tensor params = c->mutable_input(0, use_exclusive_lock_);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    const TensorShape& params_shape(params.shape());

    int64 indices_nd;
    Index num_updates;
    Index slice_size;

    OP_REQUIRES(c, params.IsInitialized(),
                errors::FailedPrecondition("Null ref for params"));
    PrepareAndValidateInputs<Index>(c, params_shape, indices, updates,
                                    &indices_nd, &num_updates, &slice_size);
    if (!c->status().ok()) return;

    // The scratch scalar is viewed as Index below, so it must be allocated
    // with the matching dtype (not unconditionally DT_INT32).
    Tensor scratch;
    OP_REQUIRES_OK(c, c->allocate_temp(DataTypeToEnum<Index>::v(),
                                       TensorShape(), &scratch));

    auto scratch_scalar = scratch.scalar<Index>();
    auto indices_flat = indices.flat_inner_dims<Index>();
    auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

    // Stays negative unless the functor reports an out-of-bounds index row.
    Index bad_i = -1;
    c->forward_ref_input_to_ref_output(0, 0);
    switch (indices_nd) {
#define PARAMS_CASE(IXDIM)                                                 \
  case IXDIM: {                                                            \
    auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>();           \
    functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor;        \
    bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
                    params_flat, indices_flat, updates_flat, params_flat); \
  } break
      PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
#undef PARAMS_CASE
      default:
        // Cases 0 through 5 are dispatched above, so the message says 0..5.
        OP_REQUIRES(c, false,
                    errors::InvalidArgument(
                        "Only indices.shape[-1] values between 0 and 5 "
                        "are currently supported.  Requested rank: ",
                        indices_nd));
    }
    OP_REQUIRES(
        c, bad_i < 0,
        errors::InvalidArgument(
            "Invalid indices: ", SliceDebugString(indices.shape(), bad_i),
            " = [", str_util::Join(gtl::ArraySlice<Index>(
                                       &indices_flat(bad_i, 0), indices_nd),
                                   ", "),
            "] is not in [0, ", params.dim_size(0), ")"));
  }
};
// Registers ScatterNdOp under op `name` for one (value type, index type,
// device) combination.
#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
REGISTER_KERNEL_BUILDER(Name(name) \
.Device(DEVICE_##dev) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
ScatterNdOp<dev##Device, type, index_type>)
// Registers ScatterNdUpdateOp under op `name` with update operation `op` for
// one (value type, index type, device) combination.
#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
op) \
REGISTER_KERNEL_BUILDER( \
Name(name) \
.Device(DEVICE_##dev) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
ScatterNdUpdateOp<dev##Device, type, index_type, op>)
// Expands the registration above for both int32 and int64 indices.
#define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)
#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
// Despite the name, this registers all four arithmetic variants:
// ScatterNdAdd, ScatterNdSub, ScatterNdMul and ScatterNdDiv.
#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
scatter_nd_op::UpdateOp::SUB); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
scatter_nd_op::UpdateOp::MUL); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
scatter_nd_op::UpdateOp::DIV);
// Registers the non-ref "ScatterNd" op.
#define REGISTER_SCATTER_ND(type, dev) \
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
// Registers "ScatterNdUpdate" (plain assignment into a ref input).
#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
scatter_nd_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
REGISTER_SCATTER_ND_ADD_SUB(type, CPU);
#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
REGISTER_SCATTER_ND_UPDATE(type, CPU);
#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
// Arithmetic variants are registered for number types only; the assign and
// ScatterNd variants are registered for all types.
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
// Registers GPU kernels.
// The GPU helper macros are defined, but the TF_CALL lines that would use
// them are commented out, so no GPU kernels are registered here.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
REGISTER_SCATTER_ND_ADD_SUB(type, GPU);
#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
REGISTER_SCATTER_ND_UPDATE(type, GPU);
// TODO(simister): Re-enable when GPU support is working.
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_GPU);
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_GPU);
#endif // GOOGLE_CUDA
// Undefine every registration helper macro defined above so none of them
// leaks into the rest of the translation unit.
#undef REGISTER_SCATTER_ND
#undef REGISTER_SCATTER_ND_ADD_SUB
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
#undef REGISTER_SCATTER_ND_CPU
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_UPDATE_CPU
#undef REGISTER_SCATTER_ND_UPDATE_GPU
#undef REGISTER_SCATTER_ND_KERNEL
#undef REGISTER_SCATTER_ND_KERNEL_INDEX
#undef REGISTER_SCATTER_ND_UPDATE_KERNEL
#undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
// The declared signature mirrors functor::ScatterNdFunctor::operator() in
// scatter_nd_op.h so the declarations are usable once GPU support is
// re-enabled below.
namespace functor {
#define DECLARE_GPU_SPECS_OP(T, Index, op, IXDIM)                          \
  template <>                                                              \
  Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()(      \
      const GPUDevice& d, const Index slice_size,                          \
      typename TTypes<Index>::Scalar Tscratch,                             \
      typename TTypes<T, (IXDIM) + 1>::Tensor Tparams,                     \
      typename TTypes<Index, 2>::ConstTensor Tindices,                     \
      typename TTypes<T, 2>::ConstTensor Tupdates,                         \
      typename TTypes<T, (IXDIM) + 1>::Tensor Toutput);                    \
  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

// Declares the specializations for every supported index depth (0..5).
#define DECLARE_GPU_SPECS_OPS(T, Index, op) \
  DECLARE_GPU_SPECS_OP(T, Index, op, 0);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 1);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 2);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 3);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 4);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 5)

// Declares the specializations for every update operation.
#define DECLARE_GPU_SPECS_INDEX(T, Index)                           \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD);    \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB);    \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL);    \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV);

// Declares the specializations for both supported index types.
#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64);

// TODO(simister): Re-enable when GPU support is working.
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_OPS
#undef DECLARE_GPU_SPECS_OP
}  // namespace functor
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -0,0 +1,62 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
#define TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
class OpKernelContext;
namespace scatter_nd_op {
// Selects how an update slice is combined with the existing slice. Used as a
// template argument of ScatterNdFunctor.
enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
} // namespace scatter_nd_op
namespace functor {
// Functor used by the ScatterNd ops to do the computations.
// Primary template; only operator() is declared here, device-specific
// definitions are provided elsewhere.
//
// Template parameters:
//   Device: device type the functor runs on.
//   T:      element type of params/updates/output.
//   Index:  element type of the indices.
//   op:     how each update slice is combined (see scatter_nd_op::UpdateOp).
//   IXDIM:  number of leading params dimensions addressed by an index row;
//           params and output are therefore viewed with rank IXDIM + 1.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor {
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
// Tscratch is a scalar work buffer supplied by the caller.
Index operator()(const Device& d, const Index slice_size,
typename TTypes<Index>::Scalar Tscratch,
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
typename TTypes<Index, 2>::ConstTensor Tindices,
typename TTypes<T, 2>::ConstTensor Tupdates,
typename TTypes<T, IXDIM + 1>::Tensor Toutput);
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_SCATTER_ND_OP_H_

View File

@ -0,0 +1,224 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
// Functor definitions for ScatterND ops, must be compilable by nvcc.
#define EIGEN_USE_THREADS
#include <atomic>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
class OpKernelContext;
// Specialization of UpdateExecutor to CPU
namespace generator {
// Applies one update operation to a contiguous slice of `slice_size`
// elements. Only declared here; each UpdateOp value has its own partial
// specialization below that defines Update().
template <typename T, typename Index, scatter_nd_op::UpdateOp op>
class UpdateExecutor {
public:
// input:   current slice values.
// updates: values to combine with `input`.
// output:  destination of the combined slice.
static void Update(T* input, const T* updates, T* output, Index slice_size);
};
// ASSIGN: the update slice replaces the destination slice; the current
// values are never read.
template <typename T, typename Index>
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ASSIGN> {
 public:
  static void Update(T* /* unused */, const T* updates, T* output,
                     Index slice_size) {
    const T* const updates_end = updates + slice_size;
    std::copy(updates, updates_end, output);
  }
};
// ADD: output[i] = input[i] + updates[i] for every element of the slice.
template <typename T, typename Index>
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ADD> {
 public:
  static void Update(T* input, const T* updates, T* output, Index slice_size) {
    const T* const input_end = input + slice_size;
    std::transform(input, input_end, updates, output,
                   [](const T& lhs, const T& rhs) -> T { return lhs + rhs; });
  }
};
// SUB: output[i] = input[i] - updates[i] for every element of the slice.
template <typename T, typename Index>
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::SUB> {
 public:
  static void Update(T* input, const T* updates, T* output, Index slice_size) {
    const T* const input_end = input + slice_size;
    std::transform(input, input_end, updates, output,
                   [](const T& lhs, const T& rhs) -> T { return lhs - rhs; });
  }
};
// MUL: output[i] = input[i] * updates[i] for every element of the slice.
template <typename T, typename Index>
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::MUL> {
 public:
  static void Update(T* input, const T* updates, T* output, Index slice_size) {
    const T* const input_end = input + slice_size;
    std::transform(input, input_end, updates, output,
                   [](const T& lhs, const T& rhs) -> T { return lhs * rhs; });
  }
};
// DIV: output[i] = input[i] / updates[i] for every element of the slice.
template <typename T, typename Index>
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::DIV> {
 public:
  static void Update(T* input, const T* updates, T* output, Index slice_size) {
    const T* const input_end = input + slice_size;
    std::transform(input, input_end, updates, output,
                   [](const T& lhs, const T& rhs) -> T { return lhs / rhs; });
  }
};
// Generator object whose call operator applies a single scatter update.
// Given a location `loc` (a row of Tindices), it reads the IXDIM-component
// index stored in that row, and either applies UpdateExecutor<T, Index, op>
// to the addressed slice or records `loc` in *error_loc when the index is out
// of bounds. The value it returns is always 0; the work happens as a side
// effect of the call.
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
class ScatterNdSliceGenerator {
public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ScatterNdSliceGenerator(
const Index slice_size, typename TTypes<T, IXDIM + 1>::Tensor Tparams,
typename TTypes<Index, 2>::ConstTensor Tindices,
typename TTypes<T, 2>::ConstTensor Tupdates,
typename TTypes<T, IXDIM + 1>::Tensor Toutput,
std::atomic<Index>* error_loc)
: slice_size_(slice_size),
Tparams_(Tparams),
Tindices_(Tindices),
Tupdates_(Tupdates),
Toutput_(Toutput),
error_loc_(error_loc) {}
// Fills (*ix)[0..IXDIM) with the index components stored in row `loc` of
// Tindices_ and sets the trailing component (*ix)[IXDIM] to 0.
// Returns true if any component is outside the matching dimension of
// Tparams_.
EIGEN_DEVICE_FUNC bool GenerateIndices(
const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
(*ix)[IXDIM] = 0;
bool out_of_bounds = false;
for (int i = 0; i < IXDIM; ++i) {
const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
(*ix)[i] = ix_i;
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
}
return out_of_bounds;
}
// Applies the update for location loc_array[0].
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
auto loc = loc_array[0];
Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix_params;
Eigen::array<Eigen::DenseIndex, 2> ix_updates;
// Address of the first element of row `loc` in Tupdates_.
ix_updates[0] = loc;
ix_updates[1] = 0;
const bool out_of_bounds = GenerateIndices(loc, &ix_params);
if (TF_PREDICT_FALSE(out_of_bounds)) {
// Each bad location overwrites the previously stored one.
// NOTE(review): which bad location survives when several exist presumably
// depends on evaluation order -- confirm if callers rely on a specific one.
error_loc_->store(loc);
} else {
UpdateExecutor<T, Index, op>::Update(&Tparams_(ix_params),
&Tupdates_(ix_updates),
&Toutput_(ix_params), slice_size_);
}
return static_cast<int32>(0); // Return something...
}
protected:
const Index slice_size_;
// mutable: written through from the const call operator.
mutable typename TTypes<T, IXDIM + 1>::Tensor Tparams_;
const typename TTypes<Index, 2>::ConstTensor Tindices_;
const typename TTypes<T, 2>::ConstTensor Tupdates_;
mutable typename TTypes<T, IXDIM + 1>::Tensor Toutput_;
// Not owned; receives the location of an out-of-bounds index row.
std::atomic<Index>* error_loc_;
};
} // namespace generator
namespace functor {
// Implementation of update functor for CPU.
// Builds a ScatterNdSliceGenerator and evaluates it over a 1-D expression of
// length Tindices.dimension(0); the updates are applied as a side effect of
// that evaluation.
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor<CPUDevice, T, Index, op, IXDIM> {
Index operator()(const CPUDevice& d, const Index slice_size,
typename TTypes<Index>::Scalar Tscratch,
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
typename TTypes<Index, 2>::ConstTensor Tindices,
typename TTypes<T, 2>::ConstTensor Tupdates,
typename TTypes<T, IXDIM + 1>::Tensor Toutput) {
// -1 means "no bad index seen"; the generator overwrites it on failure.
std::atomic<Index> error_loc(-1);
const Eigen::DenseIndex batch_size = Tindices.dimension(0);
// Two equivalent ways of describing the reshape/broadcast dimensions,
// selected by whether Eigen's IndexList is available.
#if !defined(EIGEN_HAS_INDEX_LIST)
Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
#else
Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
broadcast_dims.set(0, batch_size);
#endif
generator::ScatterNdSliceGenerator<T, Index, op, IXDIM> generator(
slice_size, Tparams, Tindices, Tupdates, Toutput, &error_loc);
// The scratch scalar is stretched to batch_size entries and fed through the
// generator; the sum written back to Tscratch is of the generator's return
// values, which are always 0.
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(generator)
.sum();
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
return error_loc.load();
}
};
// Explicitly instantiates the CPU functor's operator() for one
// (T, Index, op) combination at index depth CPU_PROVIDED_IXDIM.
// CPU_PROVIDED_IXDIM is not defined in this header; the including source file
// must define it before the #include.
#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index \
ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
const CPUDevice& d, const Index slice_size, \
typename TTypes<Index>::Scalar Tscratch, \
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::Tensor Toutput)
// Instantiates for both int32 and int64 indices.
#define REGISTER_SCATTER_ND_INDEX(type, op) \
REGISTER_SCATTER_ND_FULL(type, int32, op); \
REGISTER_SCATTER_ND_FULL(type, int64, op)
// Assignment variant.
#define REGISTER_SCATTER_ND_UPDATE(type) \
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);
// Arithmetic variants: add, sub, mul and div.
#define REGISTER_SCATTER_ND_MATH(type) \
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); \
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MUL); \
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::DIV);
// Assignment is instantiated for all types, arithmetic for number types only.
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH)
#undef REGISTER_SCATTER_ND_MATH
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_INDEX
#undef REGISTER_SCATTER_ND_FULL
} // namespace functor
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_

View File

@ -0,0 +1,18 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define CPU_PROVIDED_IXDIM 0
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
#undef CPU_PROVIDED_IXDIM

View File

@ -0,0 +1,18 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define CPU_PROVIDED_IXDIM 1
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
#undef CPU_PROVIDED_IXDIM

View File

@ -0,0 +1,18 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define CPU_PROVIDED_IXDIM 2
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
#undef CPU_PROVIDED_IXDIM

View File

@ -0,0 +1,18 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define CPU_PROVIDED_IXDIM 3
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
#undef CPU_PROVIDED_IXDIM

View File

@ -0,0 +1,18 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define CPU_PROVIDED_IXDIM 4
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
#undef CPU_PROVIDED_IXDIM

View File

@ -0,0 +1,19 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define CPU_PROVIDED_IXDIM 5
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
#undef CPU_PROVIDED_IXDIM

View File

@ -0,0 +1,320 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <functional>
#include <memory>
#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
class ScatterNdUpdateOpTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_ref_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
.Input(FakeInput(variable_ref_type))
.Input(FakeInput(index_type))
.Input(FakeInput(RemoveRefType(variable_ref_type)))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
// A one-element string variable is overwritten at index 0: "Brain" becomes
// "TensorFlow".
TEST_F(ScatterNdUpdateOpTest, Simple_StringType) {
MakeOp(DT_STRING_REF, DT_INT32);
AddInputFromArray<string>(TensorShape({1}), {"Brain"});
AddInputFromArray<int32>(TensorShape({1}), {0});
AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
TF_ASSERT_OK(RunOpKernel());
// Check the new state of the input: the result is read back from the first
// (mutable) input rather than from an output tensor.
Tensor params_tensor = *mutable_input(0).tensor;
Tensor expected(allocator(), DT_STRING, TensorShape({1}));
test::FillValues<string>(&expected, {"TensorFlow"});
test::ExpectTensorEqual<string>(expected, params_tensor);
}
// A one-element bool variable is overwritten at index 0: false becomes true.
TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) {
MakeOp(DT_BOOL_REF, DT_INT32);
AddInputFromArray<bool>(TensorShape({1}), {false});
AddInputFromArray<int32>(TensorShape({1}), {0});
AddInputFromArray<bool>(TensorShape({1}), {true});
TF_ASSERT_OK(RunOpKernel());
// Check the new state of the input: the result is read back from the first
// (mutable) input rather than from an output tensor.
Tensor params_tensor = *mutable_input(0).tensor;
Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
test::FillValues<bool>(&expected, {true});
test::ExpectTensorEqual<bool>(expected, params_tensor);
}
// Row-slice update of a zeroed 5x3 float variable with int32 indices: update
// rows 0, 1 and 2 are written to variable rows 0, 4 and 2 respectively.
TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
TF_ASSERT_OK(RunOpKernel());
// Check the new state of the input; rows 1 and 3 were not indexed and stay
// zero.
Tensor params_tensor = *mutable_input(0).tensor;
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
10002, 0, 0, 0, 777, 778, 779});
test::ExpectTensorEqual<float>(expected, params_tensor);
}
// Same row-slice scenario as Simple_TwoD32, but with int64 indices.
TEST_F(ScatterNdUpdateOpTest, Simple_Two64) {
MakeOp(DT_FLOAT_REF, DT_INT64);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int64>(TensorShape({3, 1}), {0, 4, 2});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
TF_ASSERT_OK(RunOpKernel());
// Check the new state of the input; rows 1 and 3 were not indexed and stay
// zero.
Tensor params_tensor = *mutable_input(0).tensor;
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
10002, 0, 0, 0, 777, 778, 779});
test::ExpectTensorEqual<float>(expected, params_tensor);
}
/*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({0}), {});
AddInputFromArray<int32>(TensorShape({0}), {});
AddInputFromArray<float>(TensorShape({0}), {});
Status s = RunOpKernel();
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Output must not have 0 elements, got shape: "))
<< s;
}*/
// Single-element update: indices and updates are both rank-1 tensors of
// length one, writing 101 into position 3 of a zeroed 5-element vector.
TEST_F(ScatterNdUpdateOpTest, Simple_ZeroD) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({1}), {3});
AddInputFromArray<float>(TensorShape({1}), {101});
TF_ASSERT_OK(RunOpKernel());
// Check the new state of the input
Tensor params_tensor = *mutable_input(0).tensor;
Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
test::ExpectTensorEqual<float>(expected, params_tensor);
}
// Element-wise update of a zeroed 5-element vector: the three scalar updates
// 100, 101 and 102 are written to positions 0, 4 and 2.
TEST_F(ScatterNdUpdateOpTest, Simple_OneD) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
TF_ASSERT_OK(RunOpKernel());
// Check the new state of the input
Tensor params_tensor = *mutable_input(0).tensor;
Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
test::ExpectTensorEqual<float>(expected, params_tensor);
}
// Indices carry an extra leading batch dimension (shape [2, 3, 1]) matched by
// updates of shape [2, 3]; the six updates land at positions 0, 4, 2, 1, 3
// and 6 of an 8-element vector.
TEST_F(ScatterNdUpdateOpTest, HigherRank) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6});
AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
TF_ASSERT_OK(RunOpKernel());
// Check the new state of the input; positions 5 and 7 were not indexed and
// stay zero.
Tensor params_tensor = *mutable_input(0).tensor;
Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
test::ExpectTensorEqual<float>(expected, params_tensor);
}
// The third index (99) is outside the 5 rows of the variable; the kernel must
// report which index entry is invalid and the allowed range.
TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 99});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
// The status text is matched by substring; the full status is streamed into
// the failure message for diagnosis.
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Invalid indices: [2,0] = [99] is not in [0, 5)"))
<< s;
}
// Indices of shape [1, 3, 1] do not share an outermost dimension with updates
// of shape [3, 3]; the kernel must reject the call and name both shapes.
TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({1, 3, 1}), {0, 4, 99});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
// The status text is matched by substring; the full status is streamed into
// the failure message for diagnosis.
EXPECT_TRUE(StringPiece(s.ToString())
.contains("The outermost dimension of updates and indices "
"must match. Got indices.shape [1,3,1], "
"updates.shape [3,3]"))
<< s;
}
// Each update slice has 4 elements while the variable's rows have 3; the
// kernel must reject the mismatched slice shape.
TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
AddInputFromArray<float>(
TensorShape({3, 4}),
{100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
Status s = RunOpKernel();
// Only the fixed prefix of the message is matched; the reported shapes that
// follow "got" are not pinned by this test.
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Must have updates.shape = indices.shape[0] + "
"params_shape[IXDIM:], got"))
<< s;
}
// Three indices are supplied but only two update rows; the kernel must reject
// the mismatched outermost dimensions.
TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
AddInputFromArray<float>(TensorShape({2, 3}),
{100, 101, 102, 10000, 10001, 10002});
Status s = RunOpKernel();
// Only the fixed prefix of the message is matched; the reported shapes that
// follow "Got" are not pinned by this test.
EXPECT_TRUE(StringPiece(s.ToString())
.contains("The outermost dimension of updates and indices "
"must match. Got "))
<< s;
}
class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
public:
virtual void TestBody() {}
void MakeBenchmarkOp(const char* op, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", op)
.Input(FakeInput(DT_FLOAT_REF))
.Input(FakeInput(index_type))
.Input(FakeInput(DT_FLOAT))
.Finalize(node_def()));
TF_CHECK_OK(InitOp());
}
};
// Shared driver for the scatter-nd benchmarks.
//
// Builds a [kRows, embedding_size] float variable (about ten million elements
// in total), 1000 random row indices, and one update row per index, then
// times `iters` runs of the kernel registered under the name `op`.
//
//   iters:          number of timed kernel invocations.
//   embedding_size: number of floats per row / per update slice.
//   op:             registered op name, e.g. "ScatterNdUpdate".
template <typename Index>
static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
  testing::StopTiming();
  const int kRows = 10000000 / embedding_size;
  const int kNumValues = kRows * embedding_size;
  std::vector<float> values;
  // Reserve space for every element pushed below (one per row *and* column),
  // so the vector is not reallocated repeatedly during setup.
  values.reserve(kNumValues);
  for (int i = 0; i < kNumValues; i++) {
    values.push_back(i);
  }
  const int kNumUpdates = 1000;
  // Fixed seeds keep the generated indices identical between runs.
  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  std::vector<Index> indices;
  indices.reserve(kNumUpdates);
  std::vector<float> updates;
  updates.reserve(static_cast<size_t>(kNumUpdates) * embedding_size);
  for (int i = 0; i < kNumUpdates; i++) {
    indices.push_back(rnd.Uniform(kRows));
    for (int j = 0; j < embedding_size; j++) {
      updates.push_back(i * 10 + j);
    }
  }
  ScatterNdUpdateBM bm;
  bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
  bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
  bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
  bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
                              updates);
  testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
                          iters);
  testing::StartTiming();
  while (iters-- > 0) {
    // NOTE(review): the returned Status is discarded, so a failing kernel
    // would be timed silently -- consider checking it once outside the loop.
    Status s = bm.RunOpKernel();
  }
  testing::StopTiming();
}
// Benchmark entry points. Each one fixes the index type and the op name and
// forwards (iters, embedding_size) to BM_ScatterNdHelper.
static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate");
}
static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate");
}
static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd");
}
static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd");
}
// Register every benchmark with the same set of arguments (1, 10, 64, 256 and
// 1024), which the entry points receive as `embedding_size`.
BENCHMARK(BM_ScatterNdUpdateInt32)
->Arg(1)
->Arg(10)
->Arg(64)
->Arg(256)
->Arg(1024);
BENCHMARK(BM_ScatterNdUpdateInt64)
->Arg(1)
->Arg(10)
->Arg(64)
->Arg(256)
->Arg(1024);
BENCHMARK(BM_ScatterNdAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterNdAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
} // namespace
} // namespace tensorflow

View File

@ -88,16 +88,12 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
// Runs `fn(first, last)` over sub-ranges of `total` units by forwarding to
// Eigen::ThreadPoolDevice::parallelFor. `cost_per_unit` is passed as the third
// TensorOpCost argument; the first two are zero. Only usable when the
// non-blocking Eigen thread pool is compiled in; otherwise it aborts.
void ParallelFor(int64 total, int64 cost_per_unit,
                 std::function<void(int64, int64)> fn) {
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
  CHECK_GE(total, 0);
  // `total` must survive a round trip through Eigen::Index unchanged.
  CHECK_EQ(total, static_cast<int64>(static_cast<Eigen::Index>(total)));
  Eigen::ThreadPoolDevice device(this, this->NumThreads());
  const Eigen::TensorOpCost unit_cost(0, 0, cost_per_unit);
  device.parallelFor(
      total, unit_cost,
      [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
#else
  CHECK(0);  // should not be used with the old thread pool
#endif
}
};

View File

@ -57,7 +57,6 @@ TEST(ThreadPool, DoWork) {
}
}
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
TEST(ThreadPool, ParallelFor) {
// Make ParallelFor use as many threads as possible.
int64 kHugeCost = 1 << 30;
@ -80,7 +79,6 @@ TEST(ThreadPool, ParallelFor) {
}
}
}
#endif
static void BM_Sequential(int iters) {
ThreadPool pool(Env::Default(), "test", kNumThreads);

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@ -49,11 +50,14 @@ string JoinPathImpl(std::initializer_list<StringPiece> paths) {
return result;
}
// Return the parts of the path, split on the final "/". If there is no
// "/" in the path, the first part of the output is empty and the second
// is the input. If the only "/" in the path is the first character, it is
// the first part of the output.
std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) {
// Return the parts of the URI, split on the final "/" in the path. If there is
// no "/" in the path, the first part of the output is the scheme and host, and
// the second is the path. If the only "/" in the path is the first character,
// it is included in the first part of the output.
std::pair<StringPiece, StringPiece> SplitPath(StringPiece uri) {
StringPiece scheme, host, path;
ParseURI(uri, &scheme, &host, &path);
auto pos = path.rfind('/');
#ifdef PLATFORM_WINDOWS
if (pos == StringPiece::npos)
@ -61,15 +65,17 @@ std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) {
#endif
// Handle the case with no '/' in 'path'.
if (pos == StringPiece::npos)
return std::make_pair(StringPiece(path.data(), 0), path);
return std::make_pair(StringPiece(uri.begin(), host.end() - uri.begin()),
path);
// Handle the case with a single leading '/' in 'path'.
if (pos == 0)
return std::make_pair(StringPiece(path.data(), 1),
StringPiece(path.data() + 1, path.size() - 1));
return std::make_pair(
StringPiece(uri.begin(), path.begin() + 1 - uri.begin()),
StringPiece(path.data() + 1, path.size() - 1));
return std::make_pair(
StringPiece(path.data(), pos),
StringPiece(uri.begin(), path.begin() + pos - uri.begin()),
StringPiece(path.data() + pos + 1, path.size() - (pos + 1)));
}
@ -185,5 +191,42 @@ string CleanPath(StringPiece unclean_path) {
return path;
}
// Splits a URI into scheme, host and path. On entry `remaining` holds the
// whole URI; the three outputs are sub-pieces of that same buffer.
//
// A scheme is recognized only when the input starts with a letter, continues
// with letters, digits or dots, and is followed by the literal "://". In every
// other case scheme and host become empty pieces anchored at the start of the
// input and the whole input is returned as the path.
void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
StringPiece* path) {
// 0. Parse scheme
// Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]*
// NOTE(review): whether Scanner::Many() accepts zero characters, as the "*"
// above implies, is not visible here -- confirm in scanner.h.
// TODO(keveman): Allow "+" and "-" in the scheme.
if (!strings::Scanner(remaining)
.One(strings::Scanner::LETTER)
.Many(strings::Scanner::LETTER_DIGIT_DOT)
.StopCapture()
.OneLiteral("://")
.GetResult(&remaining, scheme)) {
// If there's no scheme, assume the entire string is a path.
*scheme = StringPiece(remaining.begin(), 0);
*host = StringPiece(remaining.begin(), 0);
*path = remaining;
return;
}
// 1. Parse host
// On success above, `remaining` now starts just past "://". The host is
// everything up to the next '/'.
if (!strings::Scanner(remaining).ScanUntil('/').GetResult(&remaining, host)) {
// No path, so the rest of the URI is the host.
*host = remaining;
// Empty path anchored at the end of the input, so it still points into
// the caller's buffer.
*path = StringPiece(remaining.end(), 0);
return;
}
// 2. The rest is the path
*path = remaining;
}
// Assembles "<scheme>://<host><path>". When `scheme` is empty only `path` is
// returned and `host` is ignored.
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) {
  const bool has_scheme = !scheme.empty();
  if (has_scheme) {
    return strings::StrCat(scheme, "://", host, path);
  }
  return path.ToString();
}
} // namespace io
} // namespace tensorflow

View File

@ -74,6 +74,21 @@ StringPiece Extension(StringPiece path);
// string manipulation, completely independent of process state.
string CleanPath(StringPiece path);
// Populates the scheme, host, and path from a URI. scheme, host, and path are
// guaranteed by this function to point into the contents of uri, even if
// empty.
//
// Corner cases:
// - If the URI is invalid, scheme and host are set to empty strings and the
// passed string is assumed to be a path
// - If the URI omits the path (e.g. file://host), then the path is left empty.
void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host,
StringPiece* path);
// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
// return the path.
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path);
} // namespace io
} // namespace tensorflow

View File

@ -45,6 +45,8 @@ TEST(PathTest, IsAbsolutePath) {
}
TEST(PathTest, Dirname) {
EXPECT_EQ("hdfs://127.0.0.1:9000/",
Dirname("hdfs://127.0.0.1:9000/train.csv.tfrecords"));
EXPECT_EQ("/hello", Dirname("/hello/"));
EXPECT_EQ("/", Dirname("/hello"));
EXPECT_EQ("hello", Dirname("hello/world"));
@ -97,5 +99,47 @@ TEST(PathTest, CleanPath) {
EXPECT_EQ("../../bar", CleanPath("foo/../../../bar"));
}
// Parses `uri` and checks three things:
//   1. the parsed pieces equal the expected `scheme`, `host` and `path`;
//   2. CreateURI() on the expected pieces reproduces `uri`;
//   3. the begin and end of every returned piece lie within [u.begin(),
//      u.end()], i.e. each piece points into the input even when empty.
#define EXPECT_PARSE_URI(uri, scheme, host, path) \
do { \
StringPiece u(uri); \
StringPiece s, h, p; \
ParseURI(u, &s, &h, &p); \
EXPECT_EQ(scheme, s.ToString()); \
EXPECT_EQ(host, h.ToString()); \
EXPECT_EQ(path, p.ToString()); \
EXPECT_EQ(uri, CreateURI(scheme, host, path)); \
EXPECT_LE(u.begin(), s.begin()); \
EXPECT_GE(u.end(), s.begin()); \
EXPECT_LE(u.begin(), s.end()); \
EXPECT_GE(u.end(), s.end()); \
EXPECT_LE(u.begin(), h.begin()); \
EXPECT_GE(u.end(), h.begin()); \
EXPECT_LE(u.begin(), h.end()); \
EXPECT_GE(u.end(), h.end()); \
EXPECT_LE(u.begin(), p.begin()); \
EXPECT_GE(u.end(), p.begin()); \
EXPECT_LE(u.begin(), p.end()); \
EXPECT_GE(u.end(), p.end()); \
} while (0)
// Round-trips a set of URIs through ParseURI/CreateURI. The cases cover a
// host without a path, scheme-less paths, an empty host ("file:///..."), a
// dotted scheme, inputs whose would-be scheme is invalid (contains '-', is
// empty, starts with a digit, or lacks "://") and must be treated entirely as
// a path, and host:port forms with and without a trailing path.
TEST(PathTest, CreateParseURI) {
EXPECT_PARSE_URI("http://foo", "http", "foo", "");
EXPECT_PARSE_URI("/encrypted/://foo", "", "", "/encrypted/://foo");
EXPECT_PARSE_URI("/usr/local/foo", "", "", "/usr/local/foo");
EXPECT_PARSE_URI("file:///usr/local/foo", "file", "", "/usr/local/foo");
EXPECT_PARSE_URI("local.file:///usr/local/foo", "local.file", "",
"/usr/local/foo");
EXPECT_PARSE_URI("a-b:///foo", "", "", "a-b:///foo");
EXPECT_PARSE_URI(":///foo", "", "", ":///foo");
EXPECT_PARSE_URI("9dfd:///foo", "", "", "9dfd:///foo");
EXPECT_PARSE_URI("file:", "", "", "file:");
EXPECT_PARSE_URI("file:/", "", "", "file:/");
EXPECT_PARSE_URI("hdfs://localhost:8020/path/to/file", "hdfs",
"localhost:8020", "/path/to/file");
EXPECT_PARSE_URI("hdfs://localhost:8020", "hdfs", "localhost:8020", "");
EXPECT_PARSE_URI("hdfs://localhost:8020/", "hdfs", "localhost:8020", "/");
}
#undef EXPECT_PARSE_URI
} // namespace io
} // namespace tensorflow

View File

@ -4387,6 +4387,83 @@ output_min: This value is copied from input_min.
output_max: This value is copied from input_max.
)Doc");
// Registers the "ScatterNd" op: inputs `indices`, `updates` and `shape`, one
// output. `indices` and `shape` share the Tindices type (int32 or int64);
// `updates` and `output` share T.
// NOTE(review): no SetShapeFn is attached to this registration, unlike the
// FakeQuant gradient ops registered further down -- confirm this is intended.
// NOTE(review): the per-argument docs at the end of the doc string mention
// `ref` and `tensor`, which are not arguments of this op; the text looks
// copied from the ref-based scatter ops. Any rewording must be mirrored in the
// generated op listings.
REGISTER_OP("ScatterNd")
.Input("indices: Tindices")
.Input("updates: T")
.Input("shape: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Doc(
R"doc(Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices.
This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor.
TODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax.
`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `shape`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `shape`.
`updates` is Tensor of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]].
```
The simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/ScatterNd1.png" alt>
</div>
In Python, this scatter operation would look like this:
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
scatter = tf.scatter_nd(indices, updates, shape)
with tf.Session() as sess:
print sess.run(scatter)
The resulting tensor would look like this:
[0, 11, 0, 10, 9, 0, 0, 12]
We can also, insert entire slices of a higher rank tensor all at once. For example, if we wanted to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/ScatterNd2.png" alt>
</div>
In Python, this scatter operation would look like this:
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
with tf.Session() as sess:
print sess.run(scatter)
The resulting tensor would look like this:
[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
updates: A Tensor. Must have the same type as tensor. A tensor of updated values to store in ref.
shape: A vector. The shape of the resulting tensor.
output: A new tensor with the given shape and updates applied according to the indices.)doc");
REGISTER_OP("FakeQuantWithMinMaxArgs")
.Attr("min: float = -6.0")
.Attr("max: float = 6.0")
@ -4409,6 +4486,7 @@ REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
.Input("gradients: float")
.Input("inputs: float")
.Output("backprops: float")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxArgs operation.
@ -4450,6 +4528,21 @@ REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
.Output("backprops_wrt_input: float")
.Output("backprop_wrt_min: float")
.Output("backprop_wrt_max: float")
.SetShapeFn([](InferenceContext* c) {
  // Inputs 0 and 1 (gradients and inputs) must be shape-compatible; their
  // merged shape is reported for backprops_wrt_input.
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &data_shape));

  // Inputs 2 and 3 (min and max) must both be scalars.
  ShapeHandle range_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &range_shape));
  TF_RETURN_IF_ERROR(c->Merge(range_shape, c->input(3), &range_shape));

  c->set_output(0, data_shape);
  // Both range gradients share the scalar shape.
  c->set_output(1, range_shape);
  c->set_output(2, range_shape);
  return Status::OK();
})
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxVars operation.
@ -4503,6 +4596,24 @@ REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
.Output("backprops_wrt_input: float")
.Output("backprop_wrt_min: float")
.Output("backprop_wrt_max: float")
.SetShapeFn([](InferenceContext* c) {
  // Input 0 must have rank between 1 and 4 and be shape-compatible with
  // input 1; the merged shape is reported for backprops_wrt_input.
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(data_shape, 4, &data_shape));
  TF_RETURN_IF_ERROR(c->Merge(data_shape, c->input(1), &data_shape));

  // Inputs 2 and 3 must be vectors whose length equals the last dimension of
  // the data shape.
  ShapeHandle channel_vector = c->Vector(c->Dim(data_shape, -1));
  ShapeHandle range_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &range_shape));
  TF_RETURN_IF_ERROR(c->Merge(range_shape, channel_vector, &range_shape));
  TF_RETURN_IF_ERROR(c->Merge(c->input(3), range_shape, &range_shape));

  c->set_output(0, data_shape);
  // Both range gradients share the per-channel vector shape.
  c->set_output(1, range_shape);
  c->set_output(2, range_shape);
  return Status::OK();
})
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.

View File

@ -1533,4 +1533,23 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
INFER_ERROR("must be equal", op, "[5];[4];[?]");
}
// Shape-inference checks for FakeQuantWithMinMaxVarsPerChannelGradient. Each
// input spec lists four shapes separated by ';' and each expectation lists the
// three output shapes.
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");
// Fully unknown inputs: the first output is unknown, the other two are
// vectors of unknown length.
INFER_OK(op, "?;?;?;?", "?;[?];[?]");
// Known shapes of rank 1, 2 and 4: the first output is input 0, the other
// two are input 3.
INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");
// Rank check vectors.
INFER_ERROR("be equal rank", op, "[1,?,3];[1,?,3];[3];[]");
INFER_ERROR("be rank 1", op, "[1,?,3];[1,?,3];[];[3]");
INFER_ERROR("be at least rank 1", op, "[];[];[1];[1]");
INFER_ERROR("be at most rank 4", op, "[1,2,3,4,5];[1,2,3,4,5];[1];[1]");
// Vectors must match each other, and match last dim of input.
INFER_ERROR("must be equal", op, "[1,3];[1,3];[2];[3]");
INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
}
} // end namespace tensorflow

View File

@ -24732,6 +24732,321 @@ op {
}
}
}
op {
name: "ScatterNd"
input_arg {
name: "indices"
type_attr: "Tindices"
}
input_arg {
name: "updates"
type_attr: "T"
}
input_arg {
name: "shape"
type_attr: "Tindices"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
}
op {
name: "ScatterNdAdd"
input_arg {
name: "ref"
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
type_attr: "Tindices"
}
input_arg {
name: "updates"
type_attr: "T"
}
output_arg {
name: "output_ref"
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ScatterNdDiv"
input_arg {
name: "ref"
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
type_attr: "Tindices"
}
input_arg {
name: "updates"
type_attr: "T"
}
output_arg {
name: "output_ref"
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ScatterNdMul"
input_arg {
name: "ref"
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
type_attr: "Tindices"
}
input_arg {
name: "updates"
type_attr: "T"
}
output_arg {
name: "output_ref"
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ScatterNdSub"
input_arg {
name: "ref"
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
type_attr: "Tindices"
}
input_arg {
name: "updates"
type_attr: "T"
}
output_arg {
name: "output_ref"
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ScatterNdUpdate"
input_arg {
name: "ref"
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
type_attr: "Tindices"
}
input_arg {
name: "updates"
type_attr: "T"
}
output_arg {
name: "output_ref"
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: true
}
}
}
op {
name: "ScatterSub"
input_arg {

View File

@ -15427,6 +15427,362 @@ op {
summary: "Multiplies sparse updates into a variable reference."
description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] *= updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] *= updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nDuplicate entries are handled correctly: if multiple `indices` reference\nthe same location, their contributions multiply.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`."
}
op {
name: "ScatterNd"
input_arg {
name: "indices"
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into the output tensor."
type_attr: "Tindices"
}
input_arg {
name: "updates"
description: "A Tensor. Must have the same type as output. A tensor of updated values to scatter into output."
type_attr: "T"
}
input_arg {
name: "shape"
description: "A vector. The shape of the resulting tensor."
type_attr: "Tindices"
}
output_arg {
name: "output"
description: "A new tensor with the given shape and updates applied according to the indices."
type_attr: "T"
}
attr {
name: "T"
type: "type"
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
summary: "Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices."
description: "This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor.\n\nTODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax.\n\n`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `shape`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `shape`.\n\n`updates` is Tensor of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]].\n```\n\nThe simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterNd1.png\" alt>\n</div>\n\nIn Python, this scatter operation would look like this:\n\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n shape = tf.constant([8])\n scatter = tf.scatter_nd(indices, updates, shape)\n with tf.Session() as sess:\n print sess.run(scatter)\n\nThe resulting tensor would look like this:\n\n [0, 11, 0, 10, 9, 0, 0, 12]\n\nWe can also, insert entire slices of a higher rank tensor all at once. 
For example, if we wanted to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterNd2.png\" alt>\n</div>\n\nIn Python, this scatter operation would look like this:\n\n indices = tf.constant([[0], [2]])\n updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],\n [7, 7, 7, 7], [8, 8, 8, 8]],\n [[5, 5, 5, 5], [6, 6, 6, 6],\n [7, 7, 7, 7], [8, 8, 8, 8]]])\n shape = tf.constant([4, 4, 4])\n scatter = tf.scatter_nd(indices, updates, shape)\n with tf.Session() as sess:\n print sess.run(scatter)\n\nThe resulting tensor would look like this:\n\n [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],\n [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]"
}
op {
name: "ScatterNdAdd"
input_arg {
name: "ref"
description: "A mutable Tensor. Should be from a Variable node."
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
type_attr: "Tindices"
}
input_arg {
name: "updates"
description: "A Tensor. Must have the same type as ref. A tensor of updated values to add to ref."
type_attr: "T"
}
output_arg {
name: "output_ref"
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: false
}
description: "An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`."
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to add 4 scattered elements to a rank-1 tensor with 8 elements. In Python, that addition would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n add = tf.scatter_nd_add(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(add)\n\nThe resulting update to ref would look like this:\n\n [1, 13, 3, 14, 14, 6, 7, 20]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
}
op {
name: "ScatterNdDiv"
input_arg {
name: "ref"
description: "A mutable Tensor. Should be from a Variable node."
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
type_attr: "Tindices"
}
input_arg {
name: "updates"
description: "A Tensor. Must have the same type as ref. A tensor of values to divide ref by."
type_attr: "T"
}
output_arg {
name: "output_ref"
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: false
}
description: "An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Applies sparse division between `updates` and individual values or slices within a given variable according to `indices`."
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:\n\n ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([2, 3, 4, 5])\n sub = tf.scatter_nd_div(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [10, 5, 30, 13, 25, 60, 70, 16]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
}
op {
name: "ScatterNdMul"
input_arg {
name: "ref"
description: "A mutable Tensor. Should be from a Variable node."
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
type_attr: "Tindices"
}
input_arg {
name: "updates"
description: "A Tensor. Must have the same type as ref. A tensor of values to multiply ref by."
type_attr: "T"
}
output_arg {
name: "output_ref"
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: false
}
description: "An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Applies sparse multiplication between `updates` and individual values or slices within a given variable according to `indices`."
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n sub = tf.scatter_nd_mul(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [1, 22, 3, 40, 45, 6, 7, 96]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
}
op {
name: "ScatterNdSub"
input_arg {
name: "ref"
description: "A mutable Tensor. Should be from a Variable node."
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
type_attr: "Tindices"
}
input_arg {
name: "updates"
description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref."
type_attr: "T"
}
output_arg {
name: "output_ref"
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: false
}
description: "An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`."
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n sub = tf.scatter_nd_sub(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [1, -9, 3, -6, -4, 6, 7, -4]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
}
op {
name: "ScatterNdUpdate"
input_arg {
name: "ref"
description: "A mutable Tensor. Should be from a Variable node."
type_attr: "T"
is_ref: true
}
input_arg {
name: "indices"
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
type_attr: "Tindices"
}
input_arg {
name: "updates"
description: "A Tensor. Must have the same type as ref. A tensor of updated values to store in ref."
type_attr: "T"
}
output_arg {
name: "output_ref"
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: true
}
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Applies sparse `updates` to individual values or slices within a given variable according to `indices`."
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to update 4 scattered elements in a rank-1 tensor with 8 elements. In Python, that update would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n update = tf.scatter_nd_update(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(update)\n\nThe resulting update to ref would look like this:\n\n [1, 11, 3, 10, 9, 6, 7, 12]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
}
op {
name: "ScatterSub"
input_arg {

View File

@ -445,6 +445,241 @@ use_locking: If True, the operation will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
// ScatterNdUpdate: overwrites the elements or slices of `ref` selected by
// `indices` with `updates`. Accepts any element type `T`; `use_locking`
// defaults to true for this op.
REGISTER_OP("ScatterNdUpdate")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output_ref: Ref(T)")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.Doc(
R"doc(Applies sparse `updates` to individual values or slices within a given variable according to `indices`.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
For example, say we want to update 4 scattered elements in a rank-1 tensor with 8 elements. In Python, that update would look like this:
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
print sess.run(update)
The resulting update to ref would look like this:
[1, 11, 3, 10, 9, 6, 7, 12]
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
ref: A mutable Tensor. Should be from a Variable node.
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
updates: A Tensor. Must have the same type as ref. A tensor of updated values to store in ref.
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
// ScatterNdAdd: adds `updates` into the elements or slices of `ref` selected
// by `indices`. Restricted to numeric `T`; `use_locking` defaults to false.
REGISTER_OP("ScatterNdAdd")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output_ref: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Doc(
R"doc(Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
For example, say we want to add 4 scattered elements to a rank-1 tensor with 8 elements. In Python, that addition would look like this:
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
add = tf.scatter_nd_add(ref, indices, updates)
with tf.Session() as sess:
print sess.run(add)
The resulting update to ref would look like this:
[1, 13, 3, 14, 14, 6, 7, 20]
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
ref: A mutable Tensor. Should be from a Variable node.
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref.
use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
// ScatterNdSub: subtracts `updates` from the elements or slices of `ref`
// selected by `indices`. Restricted to numeric `T`; `use_locking` defaults to
// false.
REGISTER_OP("ScatterNdSub")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output_ref: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Doc(
R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
For example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this:
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
sub = tf.scatter_nd_sub(ref, indices, updates)
with tf.Session() as sess:
print sess.run(sub)
The resulting update to ref would look like this:
[1, -9, 3, -6, -4, 6, 7, -4]
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
ref: A mutable Tensor. Should be from a Variable node.
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
// ScatterNdMul: multiplies the elements or slices of `ref` selected by
// `indices` by `updates`. Restricted to numeric `T`; `use_locking` defaults to
// false.
REGISTER_OP("ScatterNdMul")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output_ref: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Doc(
R"doc(Applies sparse multiplication between `updates` and individual values or slices within a given variable according to `indices`.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
For example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
mul = tf.scatter_nd_mul(ref, indices, updates)
with tf.Session() as sess:
print sess.run(mul)
The resulting update to ref would look like this:
[1, 22, 3, 40, 45, 6, 7, 96]
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
ref: A mutable Tensor. Should be from a Variable node.
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
updates: A Tensor. Must have the same type as ref. A tensor of values to multiply ref by.
use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
// ScatterNdDiv: divides the elements or slices of `ref` selected by `indices`
// by `updates`. Restricted to numeric `T`; `use_locking` defaults to false.
REGISTER_OP("ScatterNdDiv")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output_ref: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Doc(
R"doc(Applies sparse division between `updates` and individual values or slices within a given variable according to `indices`.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
For example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:
ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([2, 3, 4, 5])
div = tf.scatter_nd_div(ref, indices, updates)
with tf.Session() as sess:
print sess.run(div)
The resulting update to ref would look like this:
[10, 5, 30, 13, 25, 60, 70, 16]
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
ref: A mutable Tensor. Should be from a Variable node.
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
updates: A Tensor. Must have the same type as ref. A tensor of values to divide ref by.
use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
REGISTER_OP("CountUpTo")
.Input("ref: Ref(T)")
.Output("output: T")

View File

@ -81,7 +81,7 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
return errors::Internal("bucket and object cannot be null.");
}
StringPiece scheme, bucketp, objectp;
ParseURI(fname, &scheme, &bucketp, &objectp);
io::ParseURI(fname, &scheme, &bucketp, &objectp);
if (scheme != "gs") {
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
fname);

View File

@ -76,16 +76,24 @@ cc_library(
name = "platformlib",
copts = tf_copts(),
deps = [
":gif",
":jpeg",
"//tensorflow/core:protos_cc",
"@com_googlesource_code_re2//:re2",
"@farmhash_archive//:farmhash",
"@gif_archive//:gif",
"@highwayhash//:sip_hash",
"@jpeg_archive//:jpeg",
"@png_archive//:png",
],
)
# Source-less wrapper target: it only forwards to the external
# @gif_archive//:gif library, so other targets can depend on ":gif".
cc_library(
name = "gif",
copts = tf_copts(),
deps = [
"@gif_archive//:gif",
],
)
cc_library(
name = "jpeg",
copts = tf_copts(),

View File

@ -70,7 +70,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {}
Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) {
StringPiece scheme, host, path;
ParseURI(fname, &scheme, &host, &path);
io::ParseURI(fname, &scheme, &host, &path);
FileSystem* file_system = file_system_registry_->Lookup(scheme.ToString());
if (!file_system) {
return errors::Unimplemented("File system scheme ", scheme,

View File

@ -229,35 +229,6 @@ TEST_F(DefaultEnvTest, LocalFileSystem) {
}
}
// Test helper: parses `uri` with ParseURI and checks each component against
// the expected `scheme`, `host` and `path`, then checks that CreateURI on the
// expected components reproduces `uri` exactly (round trip).
// Wrapped in do { ... } while (0) so it expands to a single statement.
#define EXPECT_PARSE_URI(uri, scheme, host, path) \
do { \
StringPiece s, h, p; \
ParseURI(uri, &s, &h, &p); \
EXPECT_EQ(scheme, s.ToString()); \
EXPECT_EQ(host, h.ToString()); \
EXPECT_EQ(path, p.ToString()); \
EXPECT_EQ(uri, CreateURI(scheme, host, path)); \
} while (0)
// Cases covered below:
//  - a scheme may contain '.' ("local.file"), and the host may carry a port;
//  - a scheme containing '-', starting with a digit, or empty is rejected, and
//    the whole input is returned as the path;
//  - "file:" and "file:/" (no "://") are likewise treated as plain paths;
//  - a URI with no path after the host yields an empty path.
TEST_F(DefaultEnvTest, CreateParseURI) {
EXPECT_PARSE_URI("http://foo", "http", "foo", "");
EXPECT_PARSE_URI("/encrypted/://foo", "", "", "/encrypted/://foo");
EXPECT_PARSE_URI("/usr/local/foo", "", "", "/usr/local/foo");
EXPECT_PARSE_URI("file:///usr/local/foo", "file", "", "/usr/local/foo");
EXPECT_PARSE_URI("local.file:///usr/local/foo", "local.file", "",
"/usr/local/foo");
EXPECT_PARSE_URI("a-b:///foo", "", "", "a-b:///foo");
EXPECT_PARSE_URI(":///foo", "", "", ":///foo");
EXPECT_PARSE_URI("9dfd:///foo", "", "", "9dfd:///foo");
EXPECT_PARSE_URI("file:", "", "", "file:");
EXPECT_PARSE_URI("file:/", "", "", "file:/");
EXPECT_PARSE_URI("hdfs://localhost:8020/path/to/file", "hdfs",
"localhost:8020", "/path/to/file");
EXPECT_PARSE_URI("hdfs://localhost:8020", "hdfs", "localhost:8020", "");
EXPECT_PARSE_URI("hdfs://localhost:8020/", "hdfs", "localhost:8020", "/");
}
#undef EXPECT_PARSE_URI
TEST_F(DefaultEnvTest, SleepForMicroseconds) {
const int64 start = env_->NowMicros();
const int64 sleep_time = 1e6 + 5e5;
@ -274,14 +245,14 @@ class TmpDirFileSystem : public NullFileSystem {
public:
bool FileExists(const string& dir) override {
StringPiece scheme, host, path;
ParseURI(dir, &scheme, &host, &path);
io::ParseURI(dir, &scheme, &host, &path);
if (path.empty()) return false;
return Env::Default()->FileExists(io::JoinPath(BaseDir(), path));
}
Status CreateDir(const string& dir) override {
StringPiece scheme, host, path;
ParseURI(dir, &scheme, &host, &path);
io::ParseURI(dir, &scheme, &host, &path);
if (scheme != "tmpdirfs") {
return errors::FailedPrecondition("scheme must be tmpdirfs");
}

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
@ -79,43 +78,6 @@ WritableFile::~WritableFile() {}
FileSystemRegistry::~FileSystemRegistry() {}
// Splits `remaining` (a URI) into `scheme`, `host` and `path`.
//
// Outcomes:
//  - No valid "<scheme>://" prefix: `scheme` and `host` are cleared and `path`
//    receives the input.
//  - Scheme present but no '/' after the host: `host` receives the rest and
//    `path` is cleared.
//  - Otherwise `path` receives everything from the first '/' after the host.
//
// The outputs are StringPieces assigned from `remaining`; presumably they
// alias the caller's buffer, so it must outlive them — TODO confirm.
void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
StringPiece* path) {
// 0. Parse scheme
// Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]*
// NOTE(review): if Scanner::Many requires at least one match, a one-letter
// scheme is rejected and the pattern above should end in `+` — confirm.
// TODO(keveman): Allow "+" and "-" in the scheme.
if (!strings::Scanner(remaining)
.One(strings::Scanner::LETTER)
.Many(strings::Scanner::LETTER_DIGIT_DOT)
.StopCapture()
.OneLiteral("://")
.GetResult(&remaining, scheme)) {
// If there's no scheme, assume the entire string is a path.
// NOTE(review): relies on a failed GetResult leaving `remaining` untouched —
// confirm against Scanner.
scheme->clear();
host->clear();
*path = remaining;
return;
}
// 1. Parse host
// `remaining` now starts just after "://"; scan up to the first '/'.
if (!strings::Scanner(remaining).ScanUntil('/').GetResult(&remaining, host)) {
// No path, so the rest of the URI is the host.
*host = remaining;
path->clear();
return;
}
// 2. The rest is the path
*path = remaining;
}
// Assembles "<scheme>://<host><path>". With an empty scheme the path is
// returned on its own and `host` is ignored.
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) {
  const bool has_scheme = !scheme.empty();
  return has_scheme ? strings::StrCat(scheme, "://", host, path)
                    : path.ToString();
}
Status FileSystem::GetMatchingPaths(const string& pattern,
std::vector<string>* results) {
results->clear();
@ -237,9 +199,9 @@ Status FileSystem::DeleteRecursively(const string& dirname,
Status FileSystem::RecursivelyCreateDir(const string& dirname) {
StringPiece scheme, host, remaining_dir;
ParseURI(dirname, &scheme, &host, &remaining_dir);
io::ParseURI(dirname, &scheme, &host, &remaining_dir);
std::vector<StringPiece> sub_dirs;
while (!FileExists(CreateURI(scheme, host, remaining_dir)) &&
while (!FileExists(io::CreateURI(scheme, host, remaining_dir)) &&
!remaining_dir.empty()) {
// Basename returns "" for / ending dirs.
if (!remaining_dir.ends_with("/")) {
@ -255,7 +217,7 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) {
string built_path = remaining_dir.ToString();
for (const StringPiece sub_dir : sub_dirs) {
built_path = io::JoinPath(built_path, sub_dir);
TF_RETURN_IF_ERROR(CreateDir(CreateURI(scheme, host, built_path)));
TF_RETURN_IF_ERROR(CreateDir(io::CreateURI(scheme, host, built_path)));
}
return Status::OK();
}

View File

@ -287,19 +287,6 @@ class FileSystemRegistry {
std::vector<string>* schemes) = 0;
};
// Populates the scheme, host, and path from a URI.
//
// Corner cases:
// - If the URI is invalid, scheme and host are set to empty strings and the
// passed string is assumed to be a path
// - If the URI omits the path (e.g. file://host), then the path is left empty.
void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host,
StringPiece* path);
// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
// return the path.
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_H_

View File

@ -112,7 +112,7 @@ class InterPlanetaryFileSystem : public NullFileSystem {
void ParsePath(const string& name, string* parsed_path) {
StringPiece scheme, host, path;
ParseURI(name, &scheme, &host, &path);
io::ParseURI(name, &scheme, &host, &path);
ASSERT_EQ(scheme, "ipfs");
ASSERT_EQ(host, "solarsystem");
path.Consume("/");

View File

@ -126,7 +126,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
TF_RETURN_IF_ERROR(hdfs_->status());
StringPiece scheme, namenode, path;
ParseURI(fname, &scheme, &namenode, &path);
io::ParseURI(fname, &scheme, &namenode, &path);
const string nn = namenode.ToString();
hdfsBuilder* builder = hdfs_->hdfsNewBuilder();
@ -144,7 +144,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
string HadoopFileSystem::TranslateName(const string& name) const {
StringPiece scheme, namenode, path;
ParseURI(name, &scheme, &namenode, &path);
io::ParseURI(name, &scheme, &namenode, &path);
return path.ToString();
}

View File

@ -120,7 +120,8 @@ class PosixEnv : public Env {
symbol);
}
string FormatLibraryFileName(const string& name, const string& version) {
string FormatLibraryFileName(const string& name,
const string& version) override {
return tensorflow::internal::FormatLibraryFileName(name, version);
}
};

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
#define TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
@ -63,7 +64,7 @@ class LocalPosixFileSystem : public PosixFileSystem {
public:
string TranslateName(const string& name) const override {
StringPiece scheme, host, path;
ParseURI(name, &scheme, &host, &path);
io::ParseURI(name, &scheme, &host, &path);
return path.ToString();
}
};

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
#define TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/file_system.h"
#ifdef PLATFORM_WINDOWS
@ -68,7 +69,7 @@ class LocalWinFileSystem : public WindowsFileSystem {
public:
string TranslateName(const string& name) const override {
StringPiece scheme, host, path;
ParseURI(name, &scheme, &host, &path);
io::ParseURI(name, &scheme, &host, &path);
return path.ToString();
}
};

View File

@ -122,6 +122,10 @@ message RunStepRequest {
// Options for the run call.
RunOptions options = 5;
// Partial run handle (optional). If specified, this will be a partial run
// execution, run up to the specified fetches.
string partial_run_handle = 6;
}
message RunStepResponse {
@ -133,6 +137,42 @@ message RunStepResponse {
RunMetadata metadata = 2;
}
////////////////////////////////////////////////////////////////////////////////
//
// PartialRunSetup method request/response protos.
//
// The caller should provide the future partial run feeds, fetches, and targets.
// Then the caller can use RunStepRequest with partial_run_handle set to make
// partial run calls.
//
////////////////////////////////////////////////////////////////////////////////
message PartialRunSetupRequest {
// REQUIRED: session_handle must be returned by a CreateSession call
// to the same master service.
string session_handle = 1;
// Tensors to be fed in future steps.
repeated string feed = 2;
// Fetches. A list of tensor names. The caller expects a tensor to be returned
// for each fetch[i] (see RunStepResponse.tensor), for corresponding partial
// RunStepRequests. The order of specified fetches does not change the
// execution order.
repeated string fetch = 3;
// Target Nodes. A list of node names. The named nodes will be run in future
// steps, but their outputs will not be fetched.
repeated string target = 4;
}
message PartialRunSetupResponse {
// The unique handle corresponding to the ongoing partial run call setup by
// the invocation to PartialRunSetup. This handle may be passed to
// RunStepRequest to send and receive tensors for this partial run.
string partial_run_handle = 1;
}
////////////////////////////////////////////////////////////////////////////////
//
// CloseSession method request/response protos.

View File

@ -91,6 +91,9 @@ service MasterService {
// Extends a session.
rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse);
// Prepares future partial run calls.
rpc PartialRunSetup(PartialRunSetupRequest) returns (PartialRunSetupResponse);
// Drives the graph computation.
rpc RunStep(RunStepRequest) returns (RunStepResponse);

View File

@ -343,7 +343,11 @@ Status BundleWriter::Finish() {
status_ = env_->NewWritableFile(MetaFilename(prefix_), &file);
if (!status_.ok()) return status_;
{
table::TableBuilder builder(table::Options(), file.get());
// N.B.: the default use of Snappy compression may not be supported on all
// platforms (e.g. Android). The metadata file is small, so this is fine.
table::Options options;
options.compression = table::kNoCompression;
table::TableBuilder builder(options, file.get());
// Header entry.
BundleHeaderProto header;
header.set_num_shards(1);

View File

@ -31,12 +31,10 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
work(0, total);
return;
}
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
if (max_parallelism >= workers->NumThreads()) {
workers->ParallelFor(total, cost_per_unit, work);
return;
}
#endif
cost_per_unit = std::max(1LL, cost_per_unit);
// We shard [0, total) into "num_shards" shards.
// 1 <= num_shards <= num worker threads

View File

@ -1,34 +0,0 @@
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent">
<org.tensorflow.demo.AutoFitTextureView
android:id="@+id/texture"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:layout_alignParentStart="true"
android:layout_alignParentTop="true" />
<org.tensorflow.demo.RecognitionScoreView
android:id="@+id/results"
android:layout_width="match_parent"
android:layout_height="112dp"
android:layout_alignParentTop="true" />
</RelativeLayout>

View File

@ -22,11 +22,17 @@
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true" />
<org.tensorflow.demo.RecognitionScoreView
android:id="@+id/results"
android:layout_width="match_parent"
android:layout_height="112dp"
android:layout_alignParentTop="true" />
<org.tensorflow.demo.OverlayView
android:id="@+id/overlay"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_alignParentBottom="true" />
</RelativeLayout>

View File

@ -20,32 +20,72 @@ import android.Manifest;
import android.app.Activity;
import android.app.Fragment;
import android.content.pm.PackageManager;
import android.media.Image.Plane;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.Build;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.util.Size;
import android.view.MotionEvent;
import android.view.WindowManager;
import android.widget.Toast;
import java.nio.ByteBuffer;
import org.tensorflow.demo.env.Logger;
public abstract class CameraActivity extends Activity implements OnImageAvailableListener {
private static final Logger LOGGER = new Logger();
public abstract class CameraActivity extends Activity {
private static final int PERMISSIONS_REQUEST = 1;
private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA;
private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE;
private boolean debug = false;
private Handler handler;
private HandlerThread handlerThread;
@Override
protected void onCreate(final Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
super.onCreate(null);
getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
setContentView(R.layout.activity_camera);
if (hasPermission()) {
if (null == savedInstanceState) {
setFragment();
}
setFragment();
} else {
requestPermission();
}
}
/**
 * Starts a fresh "inference" handler thread and binds {@code handler} to its looper so that
 * {@code runInBackground} has somewhere to post work while the activity is resumed.
 */
@Override
public synchronized void onResume() {
  super.onResume();

  final HandlerThread inferenceThread = new HandlerThread("inference");
  inferenceThread.start();

  handlerThread = inferenceThread;
  handler = new Handler(inferenceThread.getLooper());
}
/**
 * Asks the background handler thread to quit, waits for it to finish, and drops the references
 * to it so that {@code runInBackground} stops posting work.
 */
@Override
public synchronized void onPause() {
  super.onPause();

  handlerThread.quitSafely();
  try {
    handlerThread.join();
  } catch (final InterruptedException e) {
    LOGGER.e(e, "Exception!");
  } finally {
    // Clear the references even when the join is interrupted: the looper has already been asked
    // to quit, so keeping the handler around would only defeat the null check in
    // runInBackground and let callers post to a looper that is shutting down.
    handlerThread = null;
    handler = null;
  }
}
/**
 * Posts {@code r} to the background handler. When no handler is currently set the task is
 * silently dropped.
 *
 * @param r the work to run on the background handler
 */
protected synchronized void runInBackground(final Runnable r) {
  final Handler target = handler;
  if (target == null) {
    return;
  }
  target.post(r);
}
@Override
@ -82,11 +122,47 @@ public abstract class CameraActivity extends Activity {
}
protected void setFragment() {
final Fragment fragment = CameraConnectionFragment.newInstance(
new CameraConnectionFragment.ConnectionCallback(){
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
CameraActivity.this.onPreviewSizeChosen(size, rotation);
}
},
this, getLayoutId(), getDesiredPreviewFrameSize());
getFragmentManager()
.beginTransaction()
.replace(R.id.container, createFragment())
.replace(R.id.container, fragment)
.commit();
}
protected abstract Fragment createFragment();
/**
 * Copies the contents of each image plane into the matching entry of {@code yuvBytes}.
 *
 * <p>Entries that are still {@code null} are allocated on demand at the capacity of the plane's
 * buffer, since the row stride (and therefore the required size) is not known in advance.
 *
 * @param planes the image planes to read from
 * @param yuvBytes per-plane destination arrays; null entries are allocated here
 */
protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) {
  int planeIndex = 0;
  for (final Plane plane : planes) {
    final ByteBuffer planeData = plane.getBuffer();
    byte[] target = yuvBytes[planeIndex];
    if (target == null) {
      LOGGER.i("Initializing buffer %d at size %d", planeIndex, planeData.capacity());
      target = new byte[planeData.capacity()];
      yuvBytes[planeIndex] = target;
    }
    planeData.get(target);
    ++planeIndex;
  }
}
/**
 * Flips the debug flag on every ACTION_DOWN event. Always returns {@code false}.
 */
@Override
public boolean onTouchEvent(final MotionEvent event) {
  final boolean fingerDown = MotionEvent.ACTION_DOWN == event.getAction();
  if (fingerDown) {
    debug ^= true;
  }
  return false;
}
/**
 * Returns the current value of the debug flag, which is toggled by touch-down events in
 * {@code onTouchEvent}.
 */
public boolean isDebug() {
return debug;
}
protected abstract void onPreviewSizeChosen(final Size size, final int rotation);
protected abstract int getLayoutId();
protected abstract int getDesiredPreviewFrameSize();
}

View File

@ -38,6 +38,7 @@ import android.hardware.camera2.CaptureResult;
import android.hardware.camera2.TotalCaptureResult;
import android.hardware.camera2.params.StreamConfigurationMap;
import android.media.ImageReader;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
@ -49,9 +50,6 @@ import android.view.TextureView;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Toast;
import org.tensorflow.demo.env.Logger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -59,6 +57,7 @@ import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import org.tensorflow.demo.env.Logger;
public class CameraConnectionFragment extends Fragment {
private static final Logger LOGGER = new Logger();
@ -69,8 +68,6 @@ public class CameraConnectionFragment extends Fragment {
*/
private static final int MINIMUM_PREVIEW_SIZE = 320;
private ResultsView resultsView;
/**
* Conversion from screen rotation to JPEG orientation.
*/
@ -111,6 +108,14 @@ public class CameraConnectionFragment extends Fragment {
public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
};
/**
* Callback for Activities to use to initialize their data once the
* selected preview size is known.
*/
public interface ConnectionCallback {
/**
* Invoked by the fragment after it has picked a preview size for the camera.
*
* @param size the preview size that was chosen
* @param cameraRotation the rotation value reported alongside the size (the fragment passes
*     its sensor orientation here)
*/
void onPreviewSizeChosen(Size size, int cameraRotation);
}
/**
* ID of the current {@link CameraDevice}.
*/
@ -184,16 +189,6 @@ public class CameraConnectionFragment extends Fragment {
*/
private Handler backgroundHandler;
/**
* An additional thread for running inference so as not to block the camera.
*/
private HandlerThread inferenceThread;
/**
* A {@link Handler} for running tasks in the background.
*/
private Handler inferenceHandler;
/**
* An {@link ImageReader} that handles preview frame capture.
*/
@ -215,9 +210,10 @@ public class CameraConnectionFragment extends Fragment {
private final Semaphore cameraOpenCloseLock = new Semaphore(1);
/**
* A {@link Classifier} object wrapping TensorFlow to pass frames to.
* A {@link OnImageAvailableListener} to receive frames as they are available.
*/
private final Classifier classifier;
private final OnImageAvailableListener imageListener;
/**
* The input size in pixels desired by TensorFlow (width and height of a square bitmap).
*/
@ -228,9 +224,15 @@ public class CameraConnectionFragment extends Fragment {
*/
private final int layout;
private final ConnectionCallback cameraConnectionCallback;
private CameraConnectionFragment(
final Classifier classifier, final int layout, final int inputSize) {
this.classifier = classifier;
final ConnectionCallback connectionCallback,
final OnImageAvailableListener imageListener,
final int layout, final int inputSize) {
this.cameraConnectionCallback = connectionCallback;
this.imageListener = imageListener;
this.layout = layout;
this.inputSize = inputSize;
}
@ -268,8 +270,12 @@ public class CameraConnectionFragment extends Fragment {
final Size[] choices, final int width, final int height, final Size aspectRatio) {
// Collect the supported resolutions that are at least as big as the preview Surface
final List<Size> bigEnough = new ArrayList<Size>();
final int minWidth = Math.max(width, MINIMUM_PREVIEW_SIZE);
final int minHeight = Math.max(height, MINIMUM_PREVIEW_SIZE);
for (final Size option : choices) {
if (option.getHeight() >= MINIMUM_PREVIEW_SIZE && option.getWidth() >= MINIMUM_PREVIEW_SIZE) {
if (option.getHeight() >= minHeight && option.getWidth() >= minWidth) {
LOGGER.i("Adding size: " + option.getWidth() + "x" + option.getHeight());
bigEnough.add(option);
} else {
@ -289,8 +295,9 @@ public class CameraConnectionFragment extends Fragment {
}
public static CameraConnectionFragment newInstance(
final Classifier classifier, final int layout, final int inputSize) {
return new CameraConnectionFragment(classifier, layout, inputSize);
final ConnectionCallback callback,
final OnImageAvailableListener imageListener, final int layout, final int inputSize) {
return new CameraConnectionFragment(callback, imageListener, layout, inputSize);
}
@Override
@ -302,7 +309,6 @@ public class CameraConnectionFragment extends Fragment {
@Override
public void onViewCreated(final View view, final Bundle savedInstanceState) {
textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
resultsView = (ResultsView) view.findViewById(R.id.results);
}
@Override
@ -371,7 +377,8 @@ public class CameraConnectionFragment extends Fragment {
// bus' bandwidth limitation, resulting in gorgeous previews but the storage of
// garbage capture data.
previewSize =
chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class), width, height, largest);
chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class),
inputSize, inputSize, largest);
// We fit the aspect ratio of TextureView to the size of preview we picked.
final int orientation = getResources().getConfiguration().orientation;
@ -382,6 +389,8 @@ public class CameraConnectionFragment extends Fragment {
}
CameraConnectionFragment.this.cameraId = cameraId;
cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation);
return;
}
} catch (final CameraAccessException e) {
@ -446,10 +455,6 @@ public class CameraConnectionFragment extends Fragment {
backgroundThread = new HandlerThread("ImageListener");
backgroundThread.start();
backgroundHandler = new Handler(backgroundThread.getLooper());
inferenceThread = new HandlerThread("InferenceThread");
inferenceThread.start();
inferenceHandler = new Handler(inferenceThread.getLooper());
}
/**
@ -457,22 +462,15 @@ public class CameraConnectionFragment extends Fragment {
*/
private void stopBackgroundThread() {
backgroundThread.quitSafely();
inferenceThread.quitSafely();
try {
backgroundThread.join();
backgroundThread = null;
backgroundHandler = null;
inferenceThread.join();
inferenceThread = null;
inferenceThread = null;
} catch (final InterruptedException e) {
LOGGER.e(e, "Exception!");
}
}
private final TensorFlowImageListener tfPreviewListener = new TensorFlowImageListener();
private final CameraCaptureSession.CaptureCallback captureCallback =
new CameraCaptureSession.CaptureCallback() {
@Override
@ -513,7 +511,7 @@ public class CameraConnectionFragment extends Fragment {
ImageReader.newInstance(
previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
previewReader.setOnImageAvailableListener(tfPreviewListener, backgroundHandler);
previewReader.setOnImageAvailableListener(imageListener, backgroundHandler);
previewRequestBuilder.addTarget(previewReader.getSurface());
// Here, we create a CameraCaptureSession for camera preview.
@ -557,11 +555,6 @@ public class CameraConnectionFragment extends Fragment {
} catch (final CameraAccessException e) {
LOGGER.e(e, "Exception!");
}
LOGGER.i("Getting assets.");
tfPreviewListener.initialize(
classifier, resultsView, inputSize, inferenceHandler, sensorOrientation);
LOGGER.i("TensorFlow initialized.");
}
/**

View File

@ -16,12 +16,29 @@
package org.tensorflow.demo;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.media.Image;
import android.media.Image.Plane;
import android.media.ImageReader;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.os.Trace;
import android.util.Size;
import android.util.TypedValue;
import android.view.Display;
import java.io.IOException;
import android.app.Fragment;
import java.util.List;
import java.util.Vector;
import org.tensorflow.demo.OverlayView.DrawCallback;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
public class ClassifierActivity extends CameraActivity {
public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener {
private static final Logger LOGGER = new Logger();
// These are the settings for the original v1 Inception model. If you want to
@ -41,9 +58,58 @@ public class ClassifierActivity extends CameraActivity {
private static final String LABEL_FILE =
"file:///android_asset/imagenet_comp_graph_label_strings.txt";
private static final boolean SAVE_PREVIEW_BITMAP = false;
private static final boolean MAINTAIN_ASPECT = true;
private TensorFlowImageClassifier classifier;
private Integer sensorOrientation;
private int previewWidth = 0;
private int previewHeight = 0;
private byte[][] yuvBytes;
private int[] rgbBytes = null;
private Bitmap rgbFrameBitmap = null;
private Bitmap croppedBitmap = null;
private Bitmap cropCopyBitmap;
private boolean computing = false;
private long timestamp = 0;
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;
private ResultsView resultsView;
private OverlayView overlayView;
private BorderedText borderedText;
private long lastProcessingTimeMs;
@Override
protected Fragment createFragment() {
final TensorFlowImageClassifier classifier = new TensorFlowImageClassifier();
protected int getLayoutId() {
return R.layout.camera_connection_fragment;
}
/**
 * Returns the preview frame size this activity asks for, which is the classifier's
 * {@code INPUT_SIZE}.
 */
@Override
protected int getDesiredPreviewFrameSize() {
return INPUT_SIZE;
}
private static final float TEXT_SIZE_DIP = 18;
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
final float textSizePx = TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
classifier = new TensorFlowImageClassifier();
try {
classifier.initializeTensorFlow(
getAssets(), MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD,
@ -52,7 +118,151 @@ public class ClassifierActivity extends CameraActivity {
LOGGER.e(e, "Exception!");
}
return CameraConnectionFragment.newInstance(
classifier, R.layout.camera_connection_fragment, INPUT_SIZE);
overlayView = (OverlayView) findViewById(R.id.overlay);
resultsView = (ResultsView) findViewById(R.id.results);
previewWidth = size.getWidth();
previewHeight = size.getHeight();
final Display display = getWindowManager().getDefaultDisplay();
final int screenOrientation = display.getRotation();
LOGGER.i("Sensor orientation: %d, Screen orientation: %d",
rotation, screenOrientation);
sensorOrientation = rotation + screenOrientation;
if (sensorOrientation % 180 == 90) {
overlayView.setAspectRatio(size.getHeight(), size.getWidth());
} else {
overlayView.setAspectRatio(size.getWidth(), size.getHeight());
}
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
rgbBytes = new int[previewWidth * previewHeight];
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
frameToCropTransform = ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
INPUT_SIZE, INPUT_SIZE,
sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
yuvBytes = new byte[3][];
overlayView.addCallback(new DrawCallback() {
@Override
public void drawCallback(final Canvas canvas) {
renderDebug(canvas);
}
});
}
/**
 * Handles a new camera preview frame: converts it from YUV to ARGB, renders it into the square
 * crop bitmap and schedules classification of that crop on the background handler.
 *
 * <p>Frames that arrive while a previous frame is still being processed are closed and dropped.
 *
 * @param reader the reader from which the latest image is acquired
 */
@Override
public void onImageAvailable(final ImageReader reader) {
  Image image = null;
  // Set once this call has raised the computing flag and opened the trace section, so that the
  // error path only undoes what this call actually did.
  boolean frameClaimed = false;

  ++timestamp;

  try {
    image = reader.acquireLatestImage();

    if (image == null) {
      return;
    }

    if (computing) {
      image.close();
      return;
    }
    computing = true;
    frameClaimed = true;

    Trace.beginSection("imageAvailable");

    final Plane[] planes = image.getPlanes();
    fillBytes(planes, yuvBytes);

    final int yRowStride = planes[0].getRowStride();
    final int uvRowStride = planes[1].getRowStride();
    final int uvPixelStride = planes[1].getPixelStride();
    ImageUtils.convertYUV420ToARGB8888(
        yuvBytes[0],
        yuvBytes[1],
        yuvBytes[2],
        rgbBytes,
        previewWidth,
        previewHeight,
        yRowStride,
        uvRowStride,
        uvPixelStride,
        false);

    image.close();
  } catch (final Exception e) {
    if (image != null) {
      image.close();
    }
    LOGGER.e(e, "Exception!");
    if (frameClaimed) {
      // Release the frame slot; otherwise every subsequent frame would be dropped. The trace
      // section is only closed if it was opened, keeping begin/end calls balanced.
      computing = false;
      Trace.endSection();
    }
    return;
  }

  rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
  final Canvas canvas = new Canvas(croppedBitmap);
  canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);

  // For examining the actual TF input.
  if (SAVE_PREVIEW_BITMAP) {
    ImageUtils.saveBitmap(croppedBitmap);
  }

  // NOTE(review): runInBackground gives no indication of whether the task was accepted; if it is
  // dropped, computing stays true. Confirm whether frames can arrive while no handler is active.
  runInBackground(
      new Runnable() {
        @Override
        public void run() {
          try {
            final long startTime = SystemClock.uptimeMillis();
            final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;

            cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
            resultsView.setResults(results);
            overlayView.postInvalidate();
          } finally {
            // Always free the slot, even if classification or rendering throws.
            computing = false;
          }
        }
      });

  Trace.endSection();
}
/**
 * Draws the debug overlay: a 2x-scaled copy of the last classified crop in the bottom-right
 * corner and a few lines of status text stacked upwards from the bottom-left. Does nothing
 * unless debug mode is on and a crop copy is available.
 *
 * @param canvas the canvas to draw the overlay onto
 */
private void renderDebug(Canvas canvas) {
  if (!isDebug()) {
    return;
  }

  final Bitmap snapshot = cropCopyBitmap;
  if (snapshot == null) {
    return;
  }

  final float scaleFactor = 2;
  final Matrix placement = new Matrix();
  placement.postScale(scaleFactor, scaleFactor);
  placement.postTranslate(
      canvas.getWidth() - snapshot.getWidth() * scaleFactor,
      canvas.getHeight() - snapshot.getHeight() * scaleFactor);
  canvas.drawBitmap(snapshot, placement, new Paint());

  final Vector<String> lines = new Vector<String>();
  lines.add("Frame: " + previewWidth + "x" + previewHeight);
  lines.add("Crop: " + snapshot.getWidth() + "x" + snapshot.getHeight());
  lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
  lines.add("Rotation: " + sensorOrientation);
  lines.add("Inference time: " + lastProcessingTimeMs + "ms");

  // Each line sits one text-size higher than the previous one.
  for (int row = 0; row < lines.size(); ++row) {
    borderedText.drawText(canvas, 10,
        canvas.getHeight() - 10 - borderedText.getTextSize() * row, lines.get(row));
  }
}
}

View File

@ -0,0 +1,94 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import android.content.Context;
import android.graphics.Canvas;
import android.util.AttributeSet;
import android.view.MotionEvent;
import android.view.View;
import android.view.View.MeasureSpec;
import java.util.LinkedList;
import java.util.List;
/**
 * A simple View providing a render callback to other classes.
 *
 * <p>Registered {@link DrawCallback}s are invoked, in registration order, each time the view is
 * drawn. The view can also be constrained to an aspect ratio via {@link #setAspectRatio}.
 */
public class OverlayView extends View {
  /** Callbacks run on every draw; guarded by this view's monitor. */
  private final List<DrawCallback> callbacks = new LinkedList<DrawCallback>();

  /** Requested aspect ratio; zero in either field means "no constraint". */
  private int ratioWidth;
  private int ratioHeight;

  public OverlayView(final Context context, final AttributeSet attrs) {
    super(context, attrs);
  }

  /**
   * Interface defining the callback for client classes.
   */
  public interface DrawCallback {
    public void drawCallback(final Canvas canvas);
  }

  /**
   * Forwards the event to the superclass and reports it as not handled, so it keeps propagating.
   */
  @Override
  public boolean onTouchEvent(final MotionEvent e) {
    super.onTouchEvent(e);
    return false;
  }

  /**
   * Registers a callback to be invoked on every draw.
   *
   * <p>Synchronized on the same monitor as {@link #draw} so the list is never modified while it
   * is being iterated.
   *
   * @param callback the callback to add
   */
  public synchronized void addCallback(final DrawCallback callback) {
    callbacks.add(callback);
  }

  /** Invokes every registered callback with the canvas being drawn. */
  @Override
  public synchronized void draw(final Canvas canvas) {
    for (final DrawCallback callback : callbacks) {
      callback.drawCallback(canvas);
    }
  }

  /**
   * Sets the aspect ratio used when measuring this view and requests a new layout pass.
   *
   * @param width relative horizontal size
   * @param height relative vertical size
   * @throws IllegalArgumentException if either value is negative
   */
  public void setAspectRatio(final int width, final int height) {
    if (width < 0 || height < 0) {
      throw new IllegalArgumentException("Size cannot be negative.");
    }
    ratioWidth = width;
    ratioHeight = height;
    requestLayout();
  }

  /**
   * Measures the view. Without an aspect ratio the size from the measure specs is used as is;
   * otherwise the largest size with the requested ratio that fits inside it is chosen.
   */
  @Override
  protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
    super.onMeasure(widthMeasureSpec, heightMeasureSpec);
    final int width = MeasureSpec.getSize(widthMeasureSpec);
    final int height = MeasureSpec.getSize(heightMeasureSpec);
    if (0 == ratioWidth || 0 == ratioHeight) {
      setMeasuredDimension(width, height);
    } else {
      if (width < height * ratioWidth / ratioHeight) {
        // Width is the limiting dimension: keep it and derive the height.
        setMeasuredDimension(width, width * ratioHeight / ratioWidth);
      } else {
        // Height is the limiting dimension: keep it and derive the width.
        setMeasuredDimension(height * ratioWidth / ratioHeight, height);
      }
    }
  }
}

View File

@ -0,0 +1,119 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo.env;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Paint.Align;
import android.graphics.Paint.Style;
import android.graphics.Rect;
/**
 * Renders legible text on a canvas by drawing each string twice: first with a stroked
 * "exterior" paint that forms the border, then with a filled "interior" paint on top.
 */
public class BorderedText {
  private final Paint interiorPaint;
  private final Paint exteriorPaint;

  private final float textSize;

  /**
   * Creates a bordered text object with a white interior and a black exterior.
   *
   * @param textSize text size in pixels
   */
  public BorderedText(final float textSize) {
    this(Color.WHITE, Color.BLACK, textSize);
  }

  /**
   * Creates a bordered text object with the given interior and exterior colors and text size.
   * The border stroke is one eighth of the text size wide.
   *
   * @param interiorColor the interior text color
   * @param exteriorColor the exterior text color
   * @param textSize text size in pixels
   */
  public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) {
    this.textSize = textSize;

    interiorPaint = newOpaquePaint(interiorColor, textSize, Style.FILL);

    exteriorPaint = newOpaquePaint(exteriorColor, textSize, Style.FILL_AND_STROKE);
    exteriorPaint.setStrokeWidth(textSize / 8);
  }

  /** Builds a non-antialiased, fully opaque paint with the given color, text size and style. */
  private static Paint newOpaquePaint(final int color, final float size, final Style style) {
    final Paint paint = new Paint();
    paint.setTextSize(size);
    paint.setColor(color);
    paint.setStyle(style);
    paint.setAntiAlias(false);
    // Applied after setColor so the paint is opaque regardless of the color's alpha channel.
    paint.setAlpha(255);
    return paint;
  }

  /**
   * Draws {@code text} at the given position, border first and interior on top.
   *
   * @param canvas the canvas to draw on
   * @param posX x coordinate passed to {@code Canvas.drawText}
   * @param posY y coordinate passed to {@code Canvas.drawText}
   * @param text the string to draw
   */
  public void drawText(final Canvas canvas, final float posX, final float posY, final String text) {
    final Paint[] passes = {exteriorPaint, interiorPaint};
    for (final Paint pass : passes) {
      canvas.drawText(text, posX, posY, pass);
    }
  }

  /** Changes the color of the text interior. */
  public void setInteriorColor(final int color) {
    interiorPaint.setColor(color);
  }

  /** Changes the color of the text border. */
  public void setExteriorColor(final int color) {
    exteriorPaint.setColor(color);
  }

  /** Returns the text size, in pixels, given at construction time. */
  public float getTextSize() {
    return textSize;
  }

  /** Applies the same alpha value to both the interior and the border. */
  public void setAlpha(final int alpha) {
    interiorPaint.setAlpha(alpha);
    exteriorPaint.setAlpha(alpha);
  }

  /**
   * Measures part of {@code line} using the interior paint and stores the result in
   * {@code lineBounds}.
   */
  public void getTextBounds(
      final String line, final int index, final int count, final Rect lineBounds) {
    interiorPaint.getTextBounds(line, index, count, lineBounds);
  }

  /** Sets the text alignment of both the interior and the border. */
  public void setTextAlign(final Align align) {
    interiorPaint.setTextAlign(align);
    exteriorPaint.setTextAlign(align);
  }
}

Some files were not shown because too many files have changed in this diff Show More