commit
76b2c0630b
@ -62,8 +62,6 @@ cc_library(
|
||||
# This define (mostly) guarantees we don't link any problematic
|
||||
# code. We use it, but we do not rely on it, as evidenced above.
|
||||
"EIGEN_MPL2_ONLY",
|
||||
# TODO(jart): Use EIGEN_USE_NONBLOCKING_THREAD_POOL but first add an
|
||||
# eigen_initialize.cc file and alwayslink=1.
|
||||
],
|
||||
includes = ["."],
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -105,6 +105,7 @@ filegroup(
|
||||
"//tensorflow/contrib/framework:all_files",
|
||||
"//tensorflow/contrib/graph_editor:all_files",
|
||||
"//tensorflow/contrib/grid_rnn:all_files",
|
||||
"//tensorflow/contrib/integrate:all_files",
|
||||
"//tensorflow/contrib/layers:all_files",
|
||||
"//tensorflow/contrib/layers/kernels:all_files",
|
||||
"//tensorflow/contrib/learn:all_files",
|
||||
@ -148,7 +149,6 @@ filegroup(
|
||||
"//tensorflow/examples/image_retraining:all_files",
|
||||
"//tensorflow/examples/label_image:all_files",
|
||||
"//tensorflow/examples/learn:all_files",
|
||||
"//tensorflow/examples/skflow:all_files",
|
||||
"//tensorflow/examples/tutorials/estimators:all_files",
|
||||
"//tensorflow/examples/tutorials/mnist:all_files",
|
||||
"//tensorflow/examples/tutorials/word2vec:all_files",
|
||||
|
@ -264,6 +264,36 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nn_grad",
|
||||
srcs = ["gradients/nn_grad.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":grad_op_registry",
|
||||
":ops",
|
||||
":scope",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gradients_nn_grad_test",
|
||||
srcs = ["gradients/nn_grad_test.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":grad_op_registry",
|
||||
":grad_testutil",
|
||||
":gradient_checker",
|
||||
":nn_grad",
|
||||
":testutil",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "cc_ops",
|
||||
op_lib_names = [
|
||||
|
@ -110,20 +110,15 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const ops::Output& x,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const TensorShape& x_shape, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error) {
|
||||
Status ComputeGradientErrorInternal(const Scope& scope, const ops::Output& x,
|
||||
const TensorShape& x_shape,
|
||||
const ops::Output& y,
|
||||
const TensorShape& y_shape, Tensor* x_data,
|
||||
T* max_error) {
|
||||
const int64 x_size = x_shape.num_elements();
|
||||
const int64 y_size = y_shape.num_elements();
|
||||
|
||||
// Initialize 'x_data' to random values.
|
||||
Tensor x_data(x.type(), x_shape);
|
||||
auto x_data_flat = x_data.flat<T>();
|
||||
x_data_flat.setRandom();
|
||||
|
||||
// Initialize theoretical Jacobian to zeros.
|
||||
Tensor jacobian_t(x.type(), {x_size, y_size});
|
||||
auto jacobian_t_flat = jacobian_t.flat<T>();
|
||||
@ -131,7 +126,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
|
||||
// Compute theoretical Jacobian.
|
||||
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
|
||||
scope, x, x_shape, x_data, y, y_shape, &jacobian_t));
|
||||
scope, x, x_shape, *x_data, y, y_shape, &jacobian_t));
|
||||
|
||||
// Initialize numeric Jacobian to zeros.
|
||||
Tensor jacobian_n(x.type(), {x_size, y_size});
|
||||
@ -140,7 +135,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
|
||||
// Compute numeric Jacobian.
|
||||
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
|
||||
scope, x, x_shape, y, y_shape, 1e-3, &x_data, &jacobian_n));
|
||||
scope, x, x_shape, y, y_shape, 1e-3, x_data, &jacobian_n));
|
||||
|
||||
// Compute the maximum error between theoretical and numeric Jacobians.
|
||||
*max_error = 0.0;
|
||||
@ -154,10 +149,39 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const TensorShape& x_shape, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error) {
|
||||
// Initialize 'x_data' to random values.
|
||||
Tensor x_data(x.type(), x_shape);
|
||||
auto x_data_flat = x_data.flat<T>();
|
||||
x_data_flat.setRandom();
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(scope, x, x_shape, y, y_shape, &x_data,
|
||||
max_error);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const Tensor& x_init_value, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error) {
|
||||
// Initialize 'x_data' from 'x_init_value'.
|
||||
Tensor x_data(x_init_value);
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(scope, x, x_data.shape(), y, y_shape,
|
||||
&x_data, max_error);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_GRAD_ERR_TYPE(T) \
|
||||
template Status ComputeGradientError<T>( \
|
||||
const Scope& scope, const ops::Output& x, const TensorShape& x_shape, \
|
||||
const ops::Output& y, const TensorShape& y_shape, T* max_error)
|
||||
const ops::Output& y, const TensorShape& y_shape, T* max_error); \
|
||||
template Status ComputeGradientError<T>( \
|
||||
const Scope& scope, const ops::Output& x, const Tensor& x_init_value, \
|
||||
const ops::Output& y, const TensorShape& y_shape, T* max_error);
|
||||
|
||||
INSTANTIATE_GRAD_ERR_TYPE(float);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(double);
|
||||
|
@ -30,6 +30,12 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const TensorShape& x_shape, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error);
|
||||
|
||||
// Overload of ComputeGradientError which takes an initial value for 'x'.
|
||||
template <typename T>
|
||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const Tensor& x_init_value, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
|
||||
|
77
tensorflow/cc/gradients/nn_grad.cc
Normal file
77
tensorflow/cc/gradients/nn_grad.cc
Normal file
@ -0,0 +1,77 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
namespace {
|
||||
|
||||
Status SoftmaxGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
// Softmax gradient function.
|
||||
// p = softmax(x) maps from [batch, n] to [batch, m]
|
||||
// dp/dx = [dp0/dx0 ... dp0/dxn-1 ]
|
||||
// [ ... ... ]
|
||||
// [dpm-1/dx0 ... dpm-1/dxn-1]
|
||||
// dL/dx = dp/dx * dL/dy
|
||||
//
|
||||
// Using alternative formula:
|
||||
// dL/dx = dL/dy * y - sum(dL/dy * y) * y
|
||||
// = (dL/dy - sum(dL/dy * y)) * y
|
||||
auto y = op.output(0);
|
||||
auto dyy = Mul(scope, grad_inputs[0], y);
|
||||
auto sum = Reshape(scope, Sum(scope, dyy, {1}), {-1, 1});
|
||||
auto sub = Sub(scope, grad_inputs[0], sum);
|
||||
auto dx = Mul(scope, sub, y);
|
||||
grad_outputs->push_back(dx);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
|
||||
|
||||
Status ReluGradHelper(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto dx = ReluGrad(scope, grad_inputs[0], op.input(0));
|
||||
grad_outputs->push_back(dx);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Relu", ReluGradHelper);
|
||||
|
||||
Status Relu6GradHelper(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto dx = Relu6Grad(scope, grad_inputs[0], op.input(0));
|
||||
grad_outputs->push_back(dx);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);
|
||||
|
||||
Status EluGradHelper(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto dx = EluGrad(scope, grad_inputs[0], op.output(0));
|
||||
grad_outputs->push_back(dx);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Elu", EluGradHelper);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
91
tensorflow/cc/gradients/nn_grad_test.cc
Normal file
91
tensorflow/cc/gradients/nn_grad_test.cc
Normal file
@ -0,0 +1,91 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradient_checker.h"
|
||||
#include "tensorflow/cc/framework/testutil.h"
|
||||
#include "tensorflow/cc/gradients/grad_testutil.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using namespace ops; // NOLINT(build/namespaces)
|
||||
|
||||
namespace {
|
||||
|
||||
class NNGradTest : public ::testing::Test {
|
||||
protected:
|
||||
NNGradTest() : scope_(Scope::NewRootScope()) {}
|
||||
|
||||
void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
|
||||
const TensorShape& y_shape) {
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, x, x_shape, y, y_shape, &max_error));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
|
||||
const TensorShape& y_shape) {
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
Scope scope_;
|
||||
};
|
||||
|
||||
TEST_F(NNGradTest, SoftmaxGrad) {
|
||||
TensorShape shape({32, 10});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Softmax(scope_, x);
|
||||
RunTest(x, shape, y, shape);
|
||||
}
|
||||
|
||||
TEST_F(NNGradTest, ReluGrad) {
|
||||
TensorShape shape({5, 2});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Relu(scope_, x);
|
||||
// Avoid input values where ReLU gradient is not well defined (around zero).
|
||||
Tensor x_init_value = test::AsTensor<float>(
|
||||
{-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, {5, 2});
|
||||
RunTest(x, x_init_value, y, shape);
|
||||
}
|
||||
|
||||
TEST_F(NNGradTest, Relu6Grad) {
|
||||
TensorShape shape({5, 2});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Relu6(scope_, x);
|
||||
// Avoid input values where ReLU gradient is not well defined (around zero
|
||||
// and six).
|
||||
Tensor x_init_value = test::AsTensor<float>(
|
||||
{-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9}, {5, 2});
|
||||
RunTest(x, x_init_value, y, shape);
|
||||
}
|
||||
|
||||
TEST_F(NNGradTest, EluGrad) {
|
||||
TensorShape shape({5, 2});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Elu(scope_, x);
|
||||
Tensor x_init_value = test::AsTensor<float>(
|
||||
{-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, {5, 2});
|
||||
RunTest(x, x_init_value, y, shape);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -23,6 +23,7 @@ py_library(
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
|
||||
"//tensorflow/contrib/integrate:integrate_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib import factorization
|
||||
from tensorflow.contrib import framework
|
||||
from tensorflow.contrib import graph_editor
|
||||
from tensorflow.contrib import grid_rnn
|
||||
from tensorflow.contrib import integrate
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import learn
|
||||
from tensorflow.contrib import linear_optimizer
|
||||
|
@ -76,7 +76,7 @@ def build_split_apply_merge_model():
|
||||
|
||||
# REINFORCE forward step
|
||||
route_selection = st.StochasticTensor(
|
||||
distributions.Categorical, logits=logits)
|
||||
distributions.Categorical(logits=logits))
|
||||
|
||||
# Accessing route_selection as a Tensor below forces a sample of
|
||||
# the Categorical distribution based on its logits.
|
||||
|
@ -22,6 +22,7 @@ import tensorflow as tf
|
||||
|
||||
st = tf.contrib.bayesflow.stochastic_tensor
|
||||
sge = tf.contrib.bayesflow.stochastic_gradient_estimators
|
||||
dists = tf.contrib.distributions
|
||||
|
||||
|
||||
class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||
@ -31,7 +32,7 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||
self._final_loss = tf.constant(3.2)
|
||||
|
||||
def _testScoreFunction(self, loss_fn, expected):
|
||||
x = st.BernoulliTensor(p=self._p, loss_fn=loss_fn)
|
||||
x = st.StochasticTensor(dists.Bernoulli(p=self._p), loss_fn=loss_fn)
|
||||
sf = x.loss(self._final_loss)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
@ -62,8 +63,8 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||
def testScoreFunctionWithMeanBaseline(self):
|
||||
ema_decay = 0.8
|
||||
num_steps = 6
|
||||
x = st.BernoulliTensor(
|
||||
p=self._p,
|
||||
x = st.StochasticTensor(
|
||||
dists.Bernoulli(p=self._p),
|
||||
loss_fn=sge.get_score_function_with_baseline(
|
||||
sge.get_mean_baseline(ema_decay)))
|
||||
sf = x.loss(self._final_loss)
|
||||
@ -98,12 +99,12 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||
|
||||
def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
|
||||
ema_decay = 0.8
|
||||
x = st.BernoulliTensor(
|
||||
p=self._p,
|
||||
x = st.StochasticTensor(
|
||||
dists.Bernoulli(p=self._p),
|
||||
loss_fn=sge.get_score_function_with_baseline(
|
||||
sge.get_mean_baseline(ema_decay)))
|
||||
y = st.BernoulliTensor(
|
||||
p=self._p,
|
||||
y = st.StochasticTensor(
|
||||
dists.Bernoulli(p=self._p),
|
||||
loss_fn=sge.get_score_function_with_baseline(
|
||||
sge.get_mean_baseline(ema_decay)))
|
||||
sf_x = x.loss(self._final_loss)
|
||||
|
@ -39,9 +39,9 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
mu = [0.0, 0.1, 0.2]
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
likelihood = st.StochasticTensor(
|
||||
distributions.Normal, mu=prior, sigma=sigma)
|
||||
distributions.Normal(mu=prior, sigma=sigma))
|
||||
self.assertTrue(prior.distribution.is_reparameterized)
|
||||
self.assertTrue(likelihood.distribution.is_reparameterized)
|
||||
|
||||
@ -77,10 +77,9 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
mu = tf.constant([0.0, 0.1, 0.2])
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
prior = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
||||
likelihood = st.StochasticTensor(
|
||||
NormalNotParam, mu=prior, sigma=sigma)
|
||||
prior_2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||
likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma))
|
||||
prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||
|
||||
loss = tf.square(tf.identity(likelihood) - mu)
|
||||
part_loss = tf.square(tf.identity(prior) - mu)
|
||||
@ -155,9 +154,7 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
mu = tf.constant([0.0, 0.1, 0.2])
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
dt = st.StochasticTensor(NormalNotParam,
|
||||
mu=mu,
|
||||
sigma=sigma,
|
||||
dt = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma),
|
||||
loss_fn=None)
|
||||
self.assertEqual(None, dt.loss(tf.constant([2.0])))
|
||||
|
||||
@ -166,8 +163,8 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
mu = tf.constant([0.0, 0.1, 0.2])
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
dt1 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
||||
dt2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
||||
dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||
dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||
loss = tf.square(tf.identity(dt1)) + 10. + dt2
|
||||
|
||||
sl_all = sg.surrogate_loss([loss])
|
||||
@ -186,8 +183,8 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
class StochasticDependenciesMapTest(tf.test.TestCase):
|
||||
|
||||
def testBuildsMapOfUpstreamNodes(self):
|
||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt2 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
dt2 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
out1 = dt1.value() + 1.
|
||||
out2 = dt2.value() + 2.
|
||||
x = out1 + out2
|
||||
@ -197,11 +194,11 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
|
||||
self.assertEqual(dep_map[dt2], set([x, y]))
|
||||
|
||||
def testHandlesStackedStochasticNodes(self):
|
||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
out1 = dt1.value() + 1.
|
||||
dt2 = st.StochasticTensor(distributions.Normal, mu=out1, sigma=1.)
|
||||
dt2 = st.StochasticTensor(distributions.Normal(mu=out1, sigma=1.))
|
||||
x = dt2.value() + 2.
|
||||
dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
y = dt3.value() * 3.
|
||||
dep_map = sg._stochastic_dependencies_map([x, y])
|
||||
self.assertEqual(dep_map[dt1], set([x]))
|
||||
@ -209,10 +206,10 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
|
||||
self.assertEqual(dep_map[dt3], set([y]))
|
||||
|
||||
def testTraversesControlInputs(self):
|
||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
logits = dt1.value() * 3.
|
||||
dt2 = st.StochasticTensor(distributions.Bernoulli, logits=logits)
|
||||
dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits))
|
||||
dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
x = dt3.value()
|
||||
y = tf.ones((2, 2)) * 4.
|
||||
z = tf.ones((2, 2)) * 3.
|
||||
|
@ -35,19 +35,19 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
sigma2 = tf.constant([0.1, 0.2, 0.3])
|
||||
|
||||
prior_default = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(mu=mu, sigma=sigma))
|
||||
self.assertTrue(
|
||||
isinstance(prior_default.value_type, st.SampleAndReshapeValue))
|
||||
prior_0 = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
dist_value_type=st.SampleAndReshapeValue())
|
||||
self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue))
|
||||
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue))
|
||||
likelihood = st.StochasticTensor(
|
||||
distributions.Normal, mu=prior, sigma=sigma2)
|
||||
distributions.Normal(mu=prior, sigma=sigma2))
|
||||
self.assertTrue(
|
||||
isinstance(likelihood.value_type, st.SampleAndReshapeValue))
|
||||
|
||||
@ -77,7 +77,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
|
||||
with st.value_type(st.MeanValue()):
|
||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
self.assertTrue(isinstance(prior.value_type, st.MeanValue))
|
||||
|
||||
prior_mean = prior.mean()
|
||||
@ -94,7 +94,8 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
prior_single = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(
|
||||
mu=mu, sigma=sigma))
|
||||
|
||||
prior_single_value = prior_single.value()
|
||||
self.assertEqual(prior_single_value.get_shape(), (2, 3))
|
||||
@ -104,7 +105,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
with st.value_type(st.SampleAndReshapeValue(n=2)):
|
||||
prior_double = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(mu=mu, sigma=sigma))
|
||||
|
||||
prior_double_value = prior_double.value()
|
||||
self.assertEqual(prior_double_value.get_shape(), (4, 3))
|
||||
@ -119,7 +120,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
with st.value_type(st.SampleValue()):
|
||||
prior_single = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(mu=mu, sigma=sigma))
|
||||
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
|
||||
|
||||
prior_single_value = prior_single.value()
|
||||
@ -130,7 +131,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
with st.value_type(st.SampleValue(n=2)):
|
||||
prior_double = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(mu=mu, sigma=sigma))
|
||||
|
||||
prior_double_value = prior_double.value()
|
||||
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
|
||||
@ -143,9 +144,9 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
mu = [0.0, -1.0, 1.0]
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.MeanValue()):
|
||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
entropy = prior.entropy()
|
||||
deep_entropy = prior.entropy()
|
||||
deep_entropy = prior.distribution.entropy()
|
||||
expected_deep_entropy = distributions.Normal(
|
||||
mu=mu, sigma=sigma).entropy()
|
||||
entropies = sess.run([entropy, deep_entropy, expected_deep_entropy])
|
||||
@ -159,17 +160,15 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
# With default
|
||||
with st.value_type(st.MeanValue(stop_gradient=True)):
|
||||
dt = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
dt = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
loss = dt.loss([tf.constant(2.0)])
|
||||
self.assertTrue(loss is not None)
|
||||
self.assertAllClose(dt.distribution.log_prob(mu).eval() * 2.0,
|
||||
loss.eval())
|
||||
self.assertAllClose(
|
||||
dt.distribution.log_prob(mu).eval() * 2.0, loss.eval())
|
||||
|
||||
# With passed-in loss_fn.
|
||||
dt = st.StochasticTensor(
|
||||
distributions.Normal,
|
||||
mu=mu,
|
||||
sigma=sigma,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
dist_value_type=st.MeanValue(stop_gradient=True),
|
||||
loss_fn=sge.get_score_function_with_constant_baseline(
|
||||
baseline=tf.constant(8.0)))
|
||||
@ -204,7 +203,7 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
obs = tf.zeros((2, 3))
|
||||
z = st.ObservedStochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma, value=obs)
|
||||
distributions.Normal(mu=mu, sigma=sigma), value=obs)
|
||||
[obs_val, z_val] = sess.run([obs, z.value()])
|
||||
self.assertAllEqual(obs_val, z_val)
|
||||
|
||||
@ -216,13 +215,13 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
||||
sigma = tf.placeholder(tf.float32)
|
||||
obs = tf.placeholder(tf.float32)
|
||||
z = st.ObservedStochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma, value=obs)
|
||||
distributions.Normal(mu=mu, sigma=sigma), value=obs)
|
||||
|
||||
mu2 = tf.placeholder(tf.float32, shape=[None])
|
||||
sigma2 = tf.placeholder(tf.float32, shape=[None])
|
||||
obs2 = tf.placeholder(tf.float32, shape=[None, None])
|
||||
z2 = st.ObservedStochasticTensor(
|
||||
distributions.Normal, mu=mu2, sigma=sigma2, value=obs2)
|
||||
distributions.Normal(mu=mu2, sigma=sigma2), value=obs2)
|
||||
|
||||
coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
|
||||
self.assertEqual(coll, [z, z2])
|
||||
@ -230,27 +229,19 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
||||
def testConstructionErrors(self):
|
||||
mu = [0., 0.]
|
||||
sigma = [1., 1.]
|
||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
||||
distributions.Normal, mu=mu, sigma=sigma,
|
||||
value=tf.zeros((3,)))
|
||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
||||
distributions.Normal, mu=mu, sigma=sigma,
|
||||
value=tf.zeros((3, 1)))
|
||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
||||
distributions.Normal, mu=mu, sigma=sigma,
|
||||
value=tf.zeros((1, 2), dtype=tf.int32))
|
||||
|
||||
|
||||
class AutomaticDistributionImportTest(tf.test.TestCase):
|
||||
|
||||
def testImportNormal(self):
|
||||
self.assertTrue(hasattr(st, "NormalTensor"))
|
||||
self.assertTrue(callable(st.NormalTensor))
|
||||
norm = st.NormalTensor(mu=0.0, sigma=1.0)
|
||||
self.assertEqual(type(norm).__name__, "NormalTensor")
|
||||
self.assertTrue(isinstance(norm, st.NormalTensor))
|
||||
self.assertTrue(isinstance(norm, st.StochasticTensor))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
st.ObservedStochasticTensor,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
value=tf.zeros((3,)))
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
st.ObservedStochasticTensor,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
value=tf.zeros((3, 1)))
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
st.ObservedStochasticTensor,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
value=tf.zeros(
|
||||
(1, 2), dtype=tf.int32))
|
||||
|
@ -44,7 +44,7 @@ def mini_vae():
|
||||
x = [[-6., 3., 6.], [-8., 4., 8.]]
|
||||
prior = distributions.Normal(mu=0., sigma=1.)
|
||||
variational = st.StochasticTensor(
|
||||
distributions.Normal, mu=inference_net(x, 1), sigma=1.)
|
||||
distributions.Normal(mu=inference_net(x, 1), sigma=1.))
|
||||
vi.register_prior(variational, prior)
|
||||
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
|
||||
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
|
||||
@ -101,7 +101,7 @@ class VariationalInferenceTest(tf.test.TestCase):
|
||||
|
||||
prior = distributions.Bernoulli(0.5)
|
||||
variational = st.StochasticTensor(
|
||||
NormalNoEntropy, mu=inference_net(x, 1), sigma=1.)
|
||||
NormalNoEntropy(mu=inference_net(x, 1), sigma=1.))
|
||||
vi.register_prior(variational, prior)
|
||||
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
|
||||
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
|
||||
|
@ -44,7 +44,6 @@ from __future__ import print_function
|
||||
import abc
|
||||
import collections
|
||||
import contextlib
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
import six
|
||||
@ -79,10 +78,6 @@ class BaseStochasticTensor(object):
|
||||
def graph(self):
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def input_dict(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def value(self, name=None):
|
||||
pass
|
||||
@ -120,6 +115,7 @@ class BaseStochasticTensor(object):
|
||||
# pylint: disable=protected-access
|
||||
ops.register_tensor_conversion_function(
|
||||
BaseStochasticTensor, BaseStochasticTensor._tensor_conversion_function)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
@ -223,8 +219,8 @@ class SampleAndReshapeValue(_StochasticValueType):
|
||||
st_value = st.value()
|
||||
assertEqual(st_value.get_shape(), (4, 3))
|
||||
|
||||
dt_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
|
||||
assertEqual(dt_value_val.shape, (4, 3))
|
||||
st_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
|
||||
assertEqual(st_value_val.shape, (4, 3))
|
||||
```
|
||||
"""
|
||||
|
||||
@ -312,17 +308,16 @@ class StochasticTensor(BaseStochasticTensor):
|
||||
"""StochasticTensor is a BaseStochasticTensor backed by a distribution."""
|
||||
|
||||
def __init__(self,
|
||||
dist_cls,
|
||||
name=None,
|
||||
dist,
|
||||
name="StochasticTensor",
|
||||
dist_value_type=None,
|
||||
loss_fn=sge.score_function,
|
||||
**dist_args):
|
||||
loss_fn=sge.score_function):
|
||||
"""Construct a `StochasticTensor`.
|
||||
|
||||
`StochasticTensor` will instantiate a distribution from `dist_cls` and
|
||||
`dist_args` and its `value` method will return the same value each time
|
||||
it is called. What `value` is returned is controlled by the
|
||||
`dist_value_type` (defaults to `SampleAndReshapeValue`).
|
||||
`StochasticTensor` is backed by the `dist` distribution and its `value`
|
||||
method will return the same value each time it is called. What `value` is
|
||||
returned is controlled by the `dist_value_type` (defaults to
|
||||
`SampleAndReshapeValue`).
|
||||
|
||||
Some distributions' sample functions are not differentiable (e.g. a sample
|
||||
from a discrete distribution like a Bernoulli) and so to differentiate
|
||||
@ -338,28 +333,25 @@ class StochasticTensor(BaseStochasticTensor):
|
||||
`MeanValueType` or if `loss_fn=None`.
|
||||
|
||||
Args:
|
||||
dist_cls: a `Distribution` class.
|
||||
dist: an instance of `Distribution`.
|
||||
name: a name for this `StochasticTensor` and its ops.
|
||||
dist_value_type: a `_StochasticValueType`, which will determine what the
|
||||
`value` of this `StochasticTensor` will be. If not provided, the
|
||||
value type set with the `value_type` context manager will be used.
|
||||
loss_fn: callable that takes `(st, st.value(), influenced_loss)`, where
|
||||
loss_fn: callable that takes
|
||||
`(st, st.value(), influenced_loss)`, where
|
||||
`st` is this `StochasticTensor`, and returns a `Tensor` loss. By
|
||||
default, `loss_fn` is the `score_function`, or more precisely, the
|
||||
integral of the score function, such that when the gradient is taken,
|
||||
the score function results. See the `stochastic_gradient_estimators`
|
||||
module for additional loss functions and baselines.
|
||||
**dist_args: keyword arguments to be passed through to `dist_cls` on
|
||||
construction.
|
||||
|
||||
Raises:
|
||||
TypeError: if `dist_cls` is not a `Distribution`.
|
||||
TypeError: if `dist` is not an instance of `Distribution`.
|
||||
TypeError: if `loss_fn` is not `callable`.
|
||||
"""
|
||||
if not issubclass(dist_cls, distributions.Distribution):
|
||||
raise TypeError("dist_cls must be a subclass of Distribution")
|
||||
self._dist_cls = dist_cls
|
||||
self._dist_args = dist_args
|
||||
if not isinstance(dist, distributions.Distribution):
|
||||
raise TypeError("dist must be an instance of Distribution")
|
||||
if dist_value_type is None:
|
||||
try:
|
||||
self._value_type = get_current_value_type()
|
||||
@ -371,24 +363,17 @@ class StochasticTensor(BaseStochasticTensor):
|
||||
with value_type(dist_value_type):
|
||||
self._value_type = get_current_value_type()
|
||||
|
||||
self._value_type.declare_inputs(self, dist_args)
|
||||
|
||||
if loss_fn is not None and not callable(loss_fn):
|
||||
raise TypeError("loss_fn must be callable")
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
with ops.name_scope(name, "StochasticTensor",
|
||||
dist_args.values()) as scope:
|
||||
with ops.name_scope(name) as scope:
|
||||
self._name = scope
|
||||
self._dist = dist_cls(**dist_args)
|
||||
self._dist = dist
|
||||
self._value = self._create_value()
|
||||
|
||||
super(StochasticTensor, self).__init__()
|
||||
|
||||
@property
|
||||
def input_dict(self):
|
||||
return self._dist_args
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
return self._value_type
|
||||
@ -397,9 +382,6 @@ class StochasticTensor(BaseStochasticTensor):
|
||||
def distribution(self):
|
||||
return self._dist
|
||||
|
||||
def clone(self, name=None, **dist_args):
|
||||
return StochasticTensor(self._dist_cls, name=name, **dist_args)
|
||||
|
||||
def _create_value(self):
|
||||
"""Create the value Tensor based on the value type, store as self._value."""
|
||||
|
||||
@ -494,33 +476,28 @@ class ObservedStochasticTensor(StochasticTensor):
|
||||
"""A StochasticTensor with an observed value."""
|
||||
|
||||
# pylint: disable=super-init-not-called
|
||||
def __init__(self, dist_cls, value, name=None, **dist_args):
|
||||
def __init__(self, dist, value, name=None):
|
||||
"""Construct an `ObservedStochasticTensor`.
|
||||
|
||||
`ObservedStochasticTensor` will instantiate a distribution from `dist_cls`
|
||||
and `dist_args` but use the provided value instead of sampling from the
|
||||
distribution. The provided value argument must be appropriately shaped
|
||||
to have come from the constructed distribution.
|
||||
`ObservedStochasticTensor` is backed by distribution `dist` and uses the
|
||||
provided value instead of using the current value type to draw a value from
|
||||
the distribution. The provided value argument must be appropriately shaped
|
||||
to have come from the distribution.
|
||||
|
||||
Args:
|
||||
dist_cls: a `Distribution` class.
|
||||
dist: an instance of `Distribution`.
|
||||
value: a Tensor containing the observed value
|
||||
name: a name for this `ObservedStochasticTensor` and its ops.
|
||||
**dist_args: keyword arguments to be passed through to `dist_cls` on
|
||||
construction.
|
||||
|
||||
Raises:
|
||||
TypeError: if `dist_cls` is not a `Distribution`.
|
||||
TypeError: if `dist` is not an instance of `Distribution`.
|
||||
ValueError: if `value` is not compatible with the distribution.
|
||||
"""
|
||||
if not issubclass(dist_cls, distributions.Distribution):
|
||||
raise TypeError("dist_cls must be a subclass of Distribution")
|
||||
self._dist_cls = dist_cls
|
||||
self._dist_args = dist_args
|
||||
with ops.name_scope(name, "ObservedStochasticTensor",
|
||||
list(dist_args.values()) + [value]) as scope:
|
||||
if not isinstance(dist, distributions.Distribution):
|
||||
raise TypeError("dist must be an instance of Distribution")
|
||||
with ops.name_scope(name, "ObservedStochasticTensor", [value]) as scope:
|
||||
self._name = scope
|
||||
self._dist = dist_cls(**dist_args)
|
||||
self._dist = dist
|
||||
dist_shape = self._dist.get_batch_shape().concatenate(
|
||||
self._dist.get_event_shape())
|
||||
value = ops.convert_to_tensor(value)
|
||||
@ -538,7 +515,7 @@ class ObservedStochasticTensor(StochasticTensor):
|
||||
"sample from the distribution %s." % (value_shape, dist_shape))
|
||||
if value.dtype != self._dist.dtype:
|
||||
raise ValueError("Type of observed value (%s) does not match type of "
|
||||
"distribuiton (%s)." % (value.dtype, self._dist.dtype))
|
||||
"distribution (%s)." % (value.dtype, self._dist.dtype))
|
||||
self._value = array_ops.identity(value)
|
||||
# pylint: disable=non-parent-init-called
|
||||
BaseStochasticTensor.__init__(self)
|
||||
@ -557,39 +534,3 @@ __all__ = [
|
||||
"value_type",
|
||||
"get_current_value_type",
|
||||
]
|
||||
|
||||
_globals = globals()
|
||||
# pylint: disable=redefined-builtin
|
||||
__doc__ += "\n\n## Automatically Generated StochasticTensors\n\n"
|
||||
# pylint: enable=redefined-builtin
|
||||
for _name in sorted(dir(distributions)):
|
||||
_candidate = getattr(distributions, _name)
|
||||
if (inspect.isclass(_candidate)
|
||||
and _candidate != distributions.Distribution
|
||||
and issubclass(_candidate, distributions.Distribution)):
|
||||
_local_name = "%sTensor" % _name
|
||||
|
||||
class _WrapperTensor(StochasticTensor):
|
||||
_my_candidate = _candidate
|
||||
|
||||
def __init__(self, name=None, dist_value_type=None,
|
||||
loss_fn=sge.score_function, **dist_args):
|
||||
StochasticTensor.__init__(
|
||||
self,
|
||||
dist_cls=self._my_candidate,
|
||||
name=name,
|
||||
dist_value_type=dist_value_type,
|
||||
loss_fn=loss_fn, **dist_args)
|
||||
|
||||
_WrapperTensor.__name__ = _local_name
|
||||
_WrapperTensor.__doc__ = (
|
||||
"`%s` is a `StochasticTensor` backed by the distribution `%s`."""
|
||||
% (_local_name, _name))
|
||||
_globals[_local_name] = _WrapperTensor
|
||||
del _WrapperTensor
|
||||
del _candidate
|
||||
|
||||
__all__.append(_local_name)
|
||||
__doc__ += "@@%s\n" % _local_name
|
||||
|
||||
del _local_name
|
||||
|
@ -126,7 +126,7 @@ def get_stochastic_variable(getter,
|
||||
|
||||
dist_kwargs = dist_kwargs or {}
|
||||
dist_kwargs.update(params)
|
||||
sample = st.StochasticTensor(dist_cls, **dist_kwargs)
|
||||
sample = st.StochasticTensor(dist_cls(**dist_kwargs))
|
||||
|
||||
if prior is not None:
|
||||
if callable(prior):
|
||||
|
@ -325,7 +325,7 @@ class FillLowerTriangularTest(tf.test.TestCase):
|
||||
|
||||
def testCorrectlyMakesNoBatchLowerTril(self):
|
||||
with self.test_session():
|
||||
x = np.arange(9)
|
||||
x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
|
||||
expected = np.array(
|
||||
[[0., 0., 0.],
|
||||
[1., 2., 0.],
|
||||
@ -333,6 +333,10 @@ class FillLowerTriangularTest(tf.test.TestCase):
|
||||
actual = distribution_util.fill_lower_triangular(x)
|
||||
self.assertAllEqual(expected.shape, actual.get_shape())
|
||||
self.assertAllEqual(expected, actual.eval())
|
||||
self.assertAllEqual(
|
||||
np.concatenate([np.ones(6, dtype=np.float32),
|
||||
np.zeros(3, dtype=np.float32)]),
|
||||
tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval())
|
||||
|
||||
def testCorrectlyMakesBatchLowerTril(self):
|
||||
with self.test_session():
|
||||
|
@ -435,15 +435,14 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
|
||||
"""
|
||||
with ops.name_scope(name, values=(x,)):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
ndims = x.get_shape().ndims
|
||||
if ndims is not None and x.get_shape()[-1].value is not None:
|
||||
if (x.get_shape().ndims is not None and
|
||||
x.get_shape()[-1].value is not None):
|
||||
d = x.get_shape()[-1].value
|
||||
# d = n^2/2 + n/2 implies n is:
|
||||
n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
|
||||
final_shape = x.get_shape()[:-1].concatenate(
|
||||
tensor_shape.TensorShape([n, n]))
|
||||
else:
|
||||
ndims = array_ops.rank(x)
|
||||
d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
|
||||
# d = n^2/2 + n/2 implies n is:
|
||||
n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.),
|
||||
@ -494,7 +493,12 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
|
||||
array_ops.tile([tril_ids], [m, 1])])
|
||||
idx = array_ops.transpose(idx, [1, 2, 0])
|
||||
|
||||
y = array_ops.gather_nd(y, idx)
|
||||
if x.get_shape().ndims == 1:
|
||||
# Prefer using gather because it has a gradient.
|
||||
# We wrap the result in a list so downstream logic "just works."
|
||||
y = [array_ops.gather(y[0, :], tril_ids)]
|
||||
else:
|
||||
y = array_ops.gather_nd(y, idx)
|
||||
y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
|
||||
|
||||
y.set_shape(y.get_shape().merge_with(final_shape))
|
||||
|
@ -571,9 +571,8 @@ class WALSModel(object):
|
||||
extras = size % num_shards
|
||||
assignments = tf.maximum(ids // (ids_per_shard + 1),
|
||||
(ids - extras) // ids_per_shard)
|
||||
new_ids = tf.select(assignments < extras,
|
||||
ids % (ids_per_shard + 1),
|
||||
(ids - extras) % ids_per_shard)
|
||||
new_ids = tf.where(assignments < extras, ids % (ids_per_shard + 1),
|
||||
(ids - extras) % ids_per_shard)
|
||||
return assignments, new_ids
|
||||
return func
|
||||
|
||||
|
@ -36,7 +36,7 @@ class LocalVariableTest(tf.test.TestCase):
|
||||
variables = tf.local_variables()
|
||||
self.assertEquals(2, len(variables))
|
||||
self.assertRaises(tf.OpError, sess.run, variables)
|
||||
tf.initialize_variables(variables).run()
|
||||
tf.variables_initializer(variables).run()
|
||||
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
|
||||
|
||||
def testLocalVariableNameAndShape(self):
|
||||
@ -51,7 +51,7 @@ class LocalVariableTest(tf.test.TestCase):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.local_variable(0)
|
||||
self.assertFalse(a in tf.all_variables())
|
||||
self.assertFalse(a in tf.global_variables())
|
||||
self.assertTrue(a in tf.local_variables())
|
||||
|
||||
def testLocalVariableNotInVariablesToRestore(self):
|
||||
@ -82,7 +82,7 @@ class LocalVariableTest(tf.test.TestCase):
|
||||
def testInitializedVariableValue(self):
|
||||
with self.test_session() as sess:
|
||||
a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
|
||||
sess.run(tf.initialize_local_variables())
|
||||
sess.run(tf.local_variables_initializer())
|
||||
self.assertAllEqual(a.eval(), [0]*5)
|
||||
|
||||
|
||||
@ -439,7 +439,7 @@ class ModelVariablesTest(tf.test.TestCase):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.model_variable('a', [5])
|
||||
self.assertTrue(a in tf.all_variables())
|
||||
self.assertTrue(a in tf.global_variables())
|
||||
self.assertTrue(a in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
|
||||
self.assertFalse(a in tf.local_variables())
|
||||
|
||||
@ -474,7 +474,7 @@ class ModelVariablesTest(tf.test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
a = tf.contrib.framework.model_variable(
|
||||
'a', [5], initializer=tf.ones_initializer)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
self.assertAllEqual(a.eval(), [1]*5)
|
||||
|
||||
def testDeviceFn(self):
|
||||
@ -667,7 +667,7 @@ class AssignFromValuesTest(tf.test.TestCase):
|
||||
var_names_to_values)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
sess.run(assign_op, feed_dict)
|
||||
@ -697,7 +697,7 @@ class AssignFromValuesTest(tf.test.TestCase):
|
||||
var_names_to_values)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
sess.run(assign_op, feed_dict)
|
||||
@ -725,7 +725,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
|
||||
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -754,7 +754,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
|
||||
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -786,7 +786,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
||||
var_value = var_names_to_values[var_name]
|
||||
var_list.append(tf.Variable(var_value, name=var_name))
|
||||
saver = tf.train.Saver(var_list)
|
||||
init_op = tf.initialize_variables(var_list)
|
||||
init_op = tf.variables_initializer(var_list)
|
||||
sess.run(init_op)
|
||||
# Save the initialized values in the file at 'checkpoint_dir'
|
||||
return saver.save(sess, checkpoint_dir, global_step=global_step)
|
||||
@ -808,7 +808,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
||||
model_path, vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
sess.run(op, feed_dict)
|
||||
@ -859,7 +859,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
||||
vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
sess.run(op, feed_dict)
|
||||
@ -890,7 +890,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
var_value = var_names_to_values[var_name]
|
||||
var_list.append(tf.Variable(var_value, name=var_name))
|
||||
saver = tf.train.Saver(var_list)
|
||||
init_op = tf.initialize_variables(var_list)
|
||||
init_op = tf.variables_initializer(var_list)
|
||||
sess.run(init_op)
|
||||
# Save the initialized values in the file at 'checkpoint_dir'
|
||||
return saver.save(sess, checkpoint_dir, global_step=global_step)
|
||||
@ -912,7 +912,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
model_path, vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -938,7 +938,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
model_path, vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
@ -961,7 +961,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
model_path, vars_to_restore, reshape_variables=True)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -989,7 +989,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
with self.assertRaises(tf.errors.NotFoundError):
|
||||
@ -1015,7 +1015,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
ignore_missing_vars=True)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -1044,7 +1044,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
ignore_missing_vars=True)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
|
38
tensorflow/contrib/integrate/BUILD
Normal file
38
tensorflow/contrib/integrate/BUILD
Normal file
@ -0,0 +1,38 @@
|
||||
# Description:
|
||||
# Integration and ODE solvers for TensorFlow.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
py_library(
|
||||
name = "integrate_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/ops/odes.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "odes_test",
|
||||
srcs = ["python/ops/odes_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":integrate_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
)
|
9
tensorflow/contrib/integrate/README.md
Normal file
9
tensorflow/contrib/integrate/README.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Integration and ODE solvers for TensorFlow
|
||||
|
||||
TensorFlow equivalents to the routines provided by `scipy.integrate`. Currently
|
||||
contains a single function, `odeint`, for integrating ordinary differential
|
||||
equations.
|
||||
|
||||
Maintainers:
|
||||
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
|
||||
- Marc Coram (mcoram@google.com, github.com/mcoram)
|
64
tensorflow/contrib/integrate/__init__.py
Normal file
64
tensorflow/contrib/integrate/__init__.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Integration and ODE solvers for TensorFlow.
|
||||
|
||||
## Example: Lorenz attractor
|
||||
|
||||
We can use `odeint` to solve the
|
||||
[Lorentz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
|
||||
differential equations, a prototypical example of chaotic dynamics:
|
||||
|
||||
```python
|
||||
rho = 28.0
|
||||
sigma = 10.0
|
||||
beta = 8.0/3.0
|
||||
|
||||
def lorenz_equation(state, t):
|
||||
x, y, z = tf.unpack(state)
|
||||
dx = sigma * (y - x)
|
||||
dy = x * (rho - z) - y
|
||||
dz = x * y - beta * z
|
||||
return tf.pack([dx, dy, dz])
|
||||
|
||||
init_state = tf.constant([0, 2, 20], dtype=tf.float64)
|
||||
t = np.linspace(0, 50, num=5000)
|
||||
tensor_state, tensor_info = tf.contrib.integrate.odeint(
|
||||
lorenz_equation, init_state, t, full_output=True)
|
||||
|
||||
sess = tf.Session()
|
||||
state, info = sess.run([tensor_state, tensor_info])
|
||||
x, y, z = state.T
|
||||
plt.plot(x, z)
|
||||
```
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../../images/lorenz_attractor.png" alt>
|
||||
</div>
|
||||
|
||||
## Ops
|
||||
|
||||
@@odeint
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.integrate.python.ops.odes import *
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
||||
__all__ = make_all(__name__)
|
503
tensorflow/contrib/integrate/python/ops/odes.py
Normal file
503
tensorflow/contrib/integrate/python/ops/odes.py
Normal file
@ -0,0 +1,503 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""ODE solvers for TensorFlow."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
|
||||
|
||||
_ButcherTableau = collections.namedtuple(
|
||||
'_ButcherTableau', 'alpha beta c_sol c_mid c_error')
|
||||
|
||||
# Parameters from Shampine (1986), section 4.
|
||||
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
|
||||
alpha=[1/5, 3/10, 4/5, 8/9, 1., 1.],
|
||||
beta=[[1/5],
|
||||
[3/40, 9/40],
|
||||
[44/45, -56/15, 32/9],
|
||||
[19372/6561, -25360/2187, 64448/6561, -212/729],
|
||||
[9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
|
||||
[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]],
|
||||
c_sol=[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0],
|
||||
c_mid=[6025192743/30085553152 / 2, 0, 51252292925/65400821598 / 2,
|
||||
-2691868925/45128329728 / 2, 187940372067/1594534317056 / 2,
|
||||
-1776094331/19743644256 / 2, 11237099/235043384 / 2],
|
||||
c_error=[1951/21600 - 35/384,
|
||||
0,
|
||||
22642/50085 - 500/1113,
|
||||
451/720 - 125/192,
|
||||
-12231/42400 - -2187/6784,
|
||||
649/6300 - 11/84,
|
||||
1/60],
|
||||
)
|
||||
|
||||
|
||||
def _possibly_nonzero(x):
|
||||
return isinstance(x, ops.Tensor) or x != 0
|
||||
|
||||
|
||||
def _scaled_dot_product(scale, xs, ys, name=None):
|
||||
"""Calculate a scaled, vector inner product between lists of Tensors."""
|
||||
with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
|
||||
# Some of the parameters in our Butcher tableau include zeros. Using
|
||||
# _possibly_nonzero lets us avoid wasted computation.
|
||||
return math_ops.add_n([(scale * x) * y for x, y in zip(xs, ys)
|
||||
if _possibly_nonzero(x) or _possibly_nonzero(y)],
|
||||
name=scope)
|
||||
|
||||
|
||||
def _dot_product(xs, ys, name=None):
|
||||
"""Calculate the vector inner product between two lists of Tensors."""
|
||||
with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
|
||||
return math_ops.add_n([x * y for x, y in zip(xs, ys)], name=scope)
|
||||
|
||||
|
||||
def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU,
|
||||
name=None):
|
||||
"""Take an arbitrary Runge-Kutta step and estimate error.
|
||||
|
||||
Args:
|
||||
func: Function to evaluate like `func(y, t)` to compute the time derivative
|
||||
of `y`.
|
||||
y0: Tensor initial value for the state.
|
||||
f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
|
||||
t0: float64 scalar Tensor giving the initial time.
|
||||
dt: float64 scalar Tensor giving the size of the desired time step.
|
||||
tableau: optional _ButcherTableau describing how to take the Runge-Kutta
|
||||
step.
|
||||
name: optional name for the operation.
|
||||
|
||||
Returns:
|
||||
Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
|
||||
the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
|
||||
estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
|
||||
calculating these terms.
|
||||
"""
|
||||
with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
|
||||
y0 = ops.convert_to_tensor(y0, name='y0')
|
||||
f0 = ops.convert_to_tensor(f0, name='f0')
|
||||
t0 = ops.convert_to_tensor(t0, name='t0')
|
||||
dt = ops.convert_to_tensor(dt, name='dt')
|
||||
dt_cast = math_ops.cast(dt, y0.dtype)
|
||||
|
||||
k = [f0]
|
||||
for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
|
||||
ti = t0 + alpha_i * dt
|
||||
yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
|
||||
k.append(func(yi, ti))
|
||||
|
||||
if not (tableau.c_sol[-1] == 0 and tableau.c_sol == tableau.beta[-1]):
|
||||
# This property (true for Dormand-Prince) lets us save a few FLOPs.
|
||||
yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)
|
||||
|
||||
y1 = array_ops.identity(yi, name='%s/y1' % scope)
|
||||
f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
|
||||
y1_error = _scaled_dot_product(dt_cast, tableau.c_error, k,
|
||||
name='%s/y1_error' % scope)
|
||||
return (y1, f1, y1_error, k)
|
||||
|
||||
|
||||
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
|
||||
"""Fit coefficients for 4th order polynomial interpolation.
|
||||
|
||||
Args:
|
||||
y0: function value at the start of the interval.
|
||||
y1: function value at the end of the interval.
|
||||
y_mid: function value at the mid-point of the interval.
|
||||
f0: derivative value at the start of the interval.
|
||||
f1: derivative value at the end of the interval.
|
||||
dt: width of the interval.
|
||||
|
||||
Returns:
|
||||
List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
|
||||
`p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
|
||||
between 0 (start of interval) and 1 (end of interval).
|
||||
"""
|
||||
# a, b, c, d, e = sympy.symbols('a b c d e')
|
||||
# x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
|
||||
# p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
|
||||
# sympy.solve([p.subs(x, 0) - y0,
|
||||
# p.subs(x, 1 / 2) - y_mid,
|
||||
# p.subs(x, 1) - y1,
|
||||
# (p.diff(x) / dt).subs(x, 0) - f0,
|
||||
# (p.diff(x) / dt).subs(x, 1) - f1],
|
||||
# [a, b, c, d, e])
|
||||
# {a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid,
|
||||
# b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid,
|
||||
# c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid,
|
||||
# d: dt*f0,
|
||||
# e: y0}
|
||||
a = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0, f1, y0, y1, y_mid])
|
||||
b = _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0, f1, y0, y1, y_mid])
|
||||
c = _dot_product([-4 * dt, dt, -11, -5, 16], [f0, f1, y0, y1, y_mid])
|
||||
d = dt * f0
|
||||
e = y0
|
||||
return [a, b, c, d, e]
|
||||
|
||||
|
||||
def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
|
||||
"""Fit an interpolating polynomial to the results of a Runge-Kutta step."""
|
||||
with ops.name_scope('interp_fit_rk'):
|
||||
dt = math_ops.cast(dt, y0.dtype)
|
||||
y_mid = y0 + _scaled_dot_product(dt, tableau.c_mid, k)
|
||||
f0 = k[0]
|
||||
f1 = k[-1]
|
||||
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
|
||||
|
||||
|
||||
def _interp_evaluate(coefficients, t0, t1, t):
|
||||
"""Evaluate polynomial interpolation at the given time point.
|
||||
|
||||
Args:
|
||||
coefficients: list of Tensor coefficients as created by `interp_fit`.
|
||||
t0: scalar float64 Tensor giving the start of the interval.
|
||||
t1: scalar float64 Tensor giving the end of the interval.
|
||||
t: scalar float64 Tensor giving the desired interpolation point.
|
||||
|
||||
Returns:
|
||||
Polynomial interpolation of the coefficients at time `t`.
|
||||
"""
|
||||
with ops.name_scope('interp_evaluate'):
|
||||
t0 = ops.convert_to_tensor(t0)
|
||||
t1 = ops.convert_to_tensor(t1)
|
||||
t = ops.convert_to_tensor(t)
|
||||
|
||||
dtype = coefficients[0].dtype
|
||||
|
||||
assert_op = control_flow_ops.Assert(
|
||||
(t0 <= t) & (t <= t1),
|
||||
['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
x = math_ops.cast((t - t0) / (t1 - t0), dtype)
|
||||
|
||||
xs = [constant_op.constant(1, dtype), x]
|
||||
for _ in range(2, len(coefficients)):
|
||||
xs.append(xs[-1] * x)
|
||||
|
||||
return _dot_product(coefficients, reversed(xs))
|
||||
|
||||
|
||||
def _optimal_step_size(last_step,
|
||||
error_ratio,
|
||||
safety=0.9,
|
||||
ifactor=10.0,
|
||||
dfactor=0.2,
|
||||
order=5,
|
||||
name=None):
|
||||
"""Calculate the optimal size for the next Runge-Kutta step."""
|
||||
with ops.name_scope(
|
||||
name, 'optimal_step_size', [last_step, error_ratio]) as scope:
|
||||
error_ratio = math_ops.cast(error_ratio, last_step.dtype)
|
||||
exponent = math_ops.cast(1 / order, last_step.dtype)
|
||||
# this looks more complex than necessary, but importantly it keeps
|
||||
# error_ratio in the numerator so we can't divide by zero:
|
||||
factor = math_ops.maximum(
|
||||
1 / ifactor,
|
||||
math_ops.minimum(error_ratio ** exponent / safety, 1 / dfactor))
|
||||
return math_ops.div(last_step, factor, name=scope)
|
||||
|
||||
|
||||
def _abs_square(x):
|
||||
if x.dtype.is_complex:
|
||||
return math_ops.square(math_ops.real(x)) + math_ops.square(math_ops.imag(x))
|
||||
else:
|
||||
return math_ops.square(x)
|
||||
|
||||
|
||||
def _ta_append(tensor_array, value):
|
||||
"""Append a value to the end of a tf.TensorArray."""
|
||||
return tensor_array.write(tensor_array.size(), value)
|
||||
|
||||
|
||||
class _RungeKuttaState(collections.namedtuple(
|
||||
'_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
|
||||
"""Saved state of the Runge Kutta solver.
|
||||
|
||||
Attributes:
|
||||
y1: Tensor giving the function value at the end of the last time step.
|
||||
f1: Tensor giving derivative at the end of the last time step.
|
||||
t0: scalar float64 Tensor giving start of the last time step.
|
||||
t1: scalar float64 Tensor giving end of the last time step.
|
||||
dt: scalar float64 Tensor giving the size for the next time step.
|
||||
interp_coef: list of Tensors giving coefficients for polynomial
|
||||
interpolation between `t0` and `t1`.
|
||||
"""
|
||||
|
||||
|
||||
class _History(collections.namedtuple(
|
||||
'_History', 'integrate_points, error_ratio')):
|
||||
"""Saved integration history for use in `info_dict`.
|
||||
|
||||
Attributes:
|
||||
integrate_points: tf.TensorArray storing integrating time points.
|
||||
error_ratio: tf.TensorArray storing computed error ratios at each
|
||||
integration step.
|
||||
"""
|
||||
|
||||
|
||||
def _dopri5(func,
|
||||
y0,
|
||||
t,
|
||||
rtol,
|
||||
atol,
|
||||
full_output=False,
|
||||
first_step=None,
|
||||
safety=0.9,
|
||||
ifactor=10.0,
|
||||
dfactor=0.2,
|
||||
max_num_steps=1000,
|
||||
name=None):
|
||||
"""Solve an ODE for `odeint` using method='dopri5'."""
|
||||
|
||||
if first_step is None:
|
||||
# at some point, we might want to switch to picking the step size
|
||||
# automatically
|
||||
first_step = 1.0
|
||||
|
||||
with ops.name_scope(
|
||||
name, 'dopri5',
|
||||
[y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps]) as scope:
|
||||
|
||||
first_step = ops.convert_to_tensor(first_step, dtype=t.dtype,
|
||||
name='first_step')
|
||||
safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
|
||||
ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
|
||||
dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
|
||||
max_num_steps = ops.convert_to_tensor(max_num_steps, dtype=dtypes.int32,
|
||||
name='max_num_steps')
|
||||
|
||||
def adaptive_runge_kutta_step(rk_state, history, n_steps):
|
||||
"""Take an adaptive Runge-Kutta step to integrate the ODE."""
|
||||
y0, f0, _, t0, dt, interp_coeff = rk_state
|
||||
with ops.name_scope('assertions'):
|
||||
check_underflow = control_flow_ops.Assert(
|
||||
t0 + dt > t0, ['underflow in dt', dt])
|
||||
check_max_num_steps = control_flow_ops.Assert(
|
||||
n_steps < max_num_steps, ['max_num_steps exceeded'])
|
||||
check_numerics = control_flow_ops.Assert(
|
||||
math_ops.reduce_all(math_ops.is_finite(abs(y0))),
|
||||
['non-finite values in state `y`', y0])
|
||||
with ops.control_dependencies(
|
||||
[check_underflow, check_max_num_steps, check_numerics]):
|
||||
y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)
|
||||
|
||||
with ops.name_scope('error_ratio'):
|
||||
# We use the same approach as the dopri5 fortran code.
|
||||
error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
|
||||
tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
|
||||
# Could also use reduce_maximum here.
|
||||
error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
|
||||
accept_step = error_ratio <= 1
|
||||
|
||||
with ops.name_scope('update/rk_state'):
|
||||
# If we don't accept the step, the _RungeKuttaState will be useless
|
||||
# (covering a time-interval of size 0), but that's OK, because in such
|
||||
# cases we always immediately take another Runge-Kutta step.
|
||||
y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
|
||||
f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
|
||||
t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
|
||||
interp_coeff = control_flow_ops.cond(
|
||||
accept_step,
|
||||
lambda: _interp_fit_rk(y0, y1, k, dt),
|
||||
lambda: interp_coeff)
|
||||
dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
|
||||
rk_state = _RungeKuttaState(
|
||||
y_next, f_next, t0, t_next, dt_next, interp_coeff)
|
||||
|
||||
with ops.name_scope('update/history'):
|
||||
history = _History(_ta_append(history.integrate_points, t0 + dt),
|
||||
_ta_append(history.error_ratio, error_ratio))
|
||||
return rk_state, history, n_steps + 1
|
||||
|
||||
def interpolate(solution, history, rk_state, i):
|
||||
"""Interpolate through the next time point, integrating as necessary."""
|
||||
with ops.name_scope('interpolate'):
|
||||
rk_state, history, _ = control_flow_ops.while_loop(
|
||||
lambda rk_state, *_: t[i] > rk_state.t1,
|
||||
adaptive_runge_kutta_step,
|
||||
(rk_state, history, 0),
|
||||
name='integrate_loop')
|
||||
y = _interp_evaluate(
|
||||
rk_state.interp_coeff, rk_state.t0, rk_state.t1, t[i])
|
||||
solution = solution.write(i, y)
|
||||
return solution, history, rk_state, i + 1
|
||||
|
||||
assert_increasing = control_flow_ops.Assert(
|
||||
math_ops.reduce_all(t[1:] > t[:-1]),
|
||||
['`t` must be monotonic increasing'])
|
||||
with ops.control_dependencies([assert_increasing]):
|
||||
num_times = array_ops.size(t)
|
||||
|
||||
solution = tensor_array_ops.TensorArray(
|
||||
y0.dtype, size=num_times).write(0, y0)
|
||||
history = _History(
|
||||
integrate_points=tensor_array_ops.TensorArray(
|
||||
t.dtype, size=0, dynamic_size=True),
|
||||
error_ratio=tensor_array_ops.TensorArray(
|
||||
rtol.dtype, size=0, dynamic_size=True))
|
||||
rk_state = _RungeKuttaState(
|
||||
y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)
|
||||
|
||||
solution, history, _, _ = control_flow_ops.while_loop(
|
||||
lambda _, __, ___, i: i < num_times,
|
||||
interpolate,
|
||||
(solution, history, rk_state, 1),
|
||||
name='interpolate_loop')
|
||||
|
||||
y = solution.pack(name=scope)
|
||||
y.set_shape(t.get_shape().concatenate(y0.get_shape()))
|
||||
if not full_output:
|
||||
return y
|
||||
else:
|
||||
integrate_points = history.integrate_points.pack()
|
||||
info_dict = {'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
|
||||
'integrate_points': integrate_points,
|
||||
'error_ratio': history.error_ratio.pack()}
|
||||
return (y, info_dict)
|
||||
|
||||
|
||||
def odeint(func,
|
||||
y0,
|
||||
t,
|
||||
rtol=1e-6,
|
||||
atol=1e-12,
|
||||
method=None,
|
||||
options=None,
|
||||
full_output=False,
|
||||
name=None):
|
||||
"""Integrate a system of ordinary differential equations.
|
||||
|
||||
Solves the initial value problem for a non-stiff system of first order ode-s:
|
||||
|
||||
```
|
||||
dy/dt = func(y, t), y(t[0]) = y0
|
||||
```
|
||||
|
||||
where y is a Tensor of any shape.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# solve `dy/dt = -y`, corresponding to exponential decay
|
||||
tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
|
||||
=> [1, exp(-1), exp(-2)]
|
||||
```
|
||||
|
||||
Output dtypes and numerical precision are based on the dtypes of the inputs
|
||||
`y0` and `t`.
|
||||
|
||||
Currently, implements 5th order Runge-Kutta with adaptive step size control
|
||||
and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
|
||||
method of `scipy.integrate.ode` and MATLAB's `ode45`.
|
||||
|
||||
Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
|
||||
Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
|
||||
doi:10.2307/2008219
|
||||
|
||||
Args:
|
||||
func: Function that maps a Tensor holding the state `y` and a scalar Tensor
|
||||
`t` into a Tensor of state derivatives with respect to time.
|
||||
y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
|
||||
have any floating point or complex dtype.
|
||||
t: 1-D Tensor holding a sequence of time points for which to solve for
|
||||
`y`. The initial time point should be the first element of this sequence,
|
||||
and each time must be larger than the previous time. May have any floating
|
||||
point dtype. If not provided as a Tensor, converted to a Tensor with
|
||||
float64 dtype.
|
||||
rtol: optional float64 Tensor specifying an upper bound on relative error,
|
||||
per element of `y`.
|
||||
atol: optional float64 Tensor specifying an upper bound on absolute error,
|
||||
per element of `y`.
|
||||
method: optional string indicating the integration method to use. Currently,
|
||||
the only valid option is `'dopri5'`.
|
||||
options: optional dict of configuring options for the indicated integration
|
||||
method. Can only be provided if a `method` is explicitly set. For
|
||||
`'dopri5'`, valid options include:
|
||||
* first_step: an initial guess for the size of the first integration
|
||||
(current default: 1.0, but may later be changed to use heuristics based
|
||||
on the gradient).
|
||||
* safety: safety factor for adaptive step control, generally a constant
|
||||
in the range 0.8-1 (default: 0.9).
|
||||
* ifactor: maximum factor by which the adaptive step may be increased
|
||||
(default: 10.0).
|
||||
* dfactor: maximum factor by which the adpative step may be decreased
|
||||
(default: 0.2).
|
||||
* max_num_steps: integer maximum number of integrate steps between time
|
||||
points in `t` (default: 1000).
|
||||
full_output: optional boolean. If True, `odeint` returns a tuple
|
||||
`(y, info_dict)` describing the integration process.
|
||||
name: Optional name for this operation.
|
||||
|
||||
Returns:
|
||||
y: (N+1)-D tensor, where the first dimension corresponds to different
|
||||
time points. Contains the solved value of y for each desired time point in
|
||||
`t`, with the initial value `y0` being the first element along the first
|
||||
dimension.
|
||||
info_dict: only if `full_output == True`. A dict with the following values:
|
||||
* num_func_evals: integer Tensor counting the number of function
|
||||
evaluations.
|
||||
* integrate_points: 1D float64 Tensor with the upper bound of each
|
||||
integration time step.
|
||||
* error_ratio: 1D float Tensor with the estimated ratio of the integration
|
||||
error to the error tolerance at each integration step. An ratio greater
|
||||
than 1 corresponds to rejected steps.
|
||||
|
||||
Raises:
|
||||
ValueError: if an invalid `method` is provided.
|
||||
TypeError: if `options` is supplied without `method`, or if `t` or `y0` has
|
||||
an invalid dtype.
|
||||
"""
|
||||
if method is not None and method != 'dopri5':
|
||||
raise ValueError('invalid method: %r' % method)
|
||||
|
||||
if options is None:
|
||||
options = {}
|
||||
elif method is None:
|
||||
raise ValueError('cannot supply `options` without specifying `method`')
|
||||
|
||||
with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
|
||||
# TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
|
||||
# arbitrarily nested tuple. This will help performance and usability by
|
||||
# avoiding the need to pack/unpack in user functions.
|
||||
y0 = ops.convert_to_tensor(y0, name='y0')
|
||||
if not (y0.dtype.is_floating or y0.dtype.is_complex):
|
||||
raise TypeError('`y0` must have a floating point or complex floating '
|
||||
'point dtype')
|
||||
|
||||
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
|
||||
if not t.dtype.is_floating:
|
||||
raise TypeError('`t` must have a floating point dtype')
|
||||
|
||||
error_dtype = abs(y0).dtype
|
||||
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
|
||||
atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')
|
||||
|
||||
return _dopri5(func, y0, t,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
full_output=full_output,
|
||||
name=scope,
|
||||
**options)
|
232
tensorflow/contrib/integrate/python/ops/odes_test.py
Normal file
232
tensorflow/contrib/integrate/python/ops/odes_test.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for ODE solvers."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.integrate.python.ops import odes
|
||||
|
||||
|
||||
class OdeIntTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(OdeIntTest, self).setUp()
|
||||
# simple defaults (solution is a sin-wave)
|
||||
matrix = tf.constant([[0, 1], [-1, 0]], dtype=tf.float64)
|
||||
self.func = lambda y, t: tf.matmul(matrix, y)
|
||||
self.y0 = np.array([[1.0], [0.0]])
|
||||
|
||||
def test_odeint_exp(self):
|
||||
# Test odeint by an exponential function:
|
||||
# dy / dt = y, y(0) = 1.0.
|
||||
# Its analytical solution is y = exp(t).
|
||||
func = lambda y, t: y
|
||||
y0 = tf.constant(1.0, dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, t)
|
||||
self.assertIn('odeint', y_solved.name)
|
||||
self.assertEqual(y_solved.get_shape(), tf.TensorShape([11]))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = np.exp(t)
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_complex(self):
|
||||
# Test a complex, linear ODE:
|
||||
# dy / dt = k * y, y(0) = 1.0.
|
||||
# Its analytical solution is y = exp(k * t).
|
||||
k = 1j - 0.1
|
||||
func = lambda y, t: k * y
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, 1.0 + 0.0j, t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = np.exp(k * t)
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_riccati(self):
|
||||
# The Ricatti equation is:
|
||||
# dy / dt = (y - t) ** 2 + 1.0, y(0) = 0.5.
|
||||
# Its analytical solution is y = 1.0 / (2.0 - t) + t.
|
||||
func = lambda t, y: (y - t)**2 + 1.0
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, np.float64(0.5), t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = 1.0 / (2.0 - t) + t
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_2d_linear(self):
|
||||
# Solve the 2D linear differential equation:
|
||||
# dy1 / dt = 3.0 * y1 + 4.0 * y2,
|
||||
# dy2 / dt = -4.0 * y1 + 3.0 * y2,
|
||||
# y1(0) = 0.0,
|
||||
# y2(0) = 1.0.
|
||||
# Its analytical solution is
|
||||
# y1 = sin(4.0 * t) * exp(3.0 * t),
|
||||
# y2 = cos(4.0 * t) * exp(3.0 * t).
|
||||
matrix = tf.constant([[3.0, 4.0], [-4.0, 3.0]], dtype=tf.float64)
|
||||
func = lambda y, t: tf.matmul(matrix, y)
|
||||
|
||||
y0 = tf.constant([[0.0], [1.0]], dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
|
||||
y_true = np.zeros((len(t), 2, 1))
|
||||
y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t)
|
||||
y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t)
|
||||
self.assertAllClose(y_true, y_solved, atol=1e-5)
|
||||
|
||||
def test_odeint_higher_rank(self):
|
||||
func = lambda y, t: y
|
||||
y0 = tf.constant(1.0, dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
for shape in [(), (1,), (1, 1)]:
|
||||
expected_shape = (len(t),) + shape
|
||||
y_solved = tf.contrib.integrate.odeint(func, tf.reshape(y0, shape), t)
|
||||
self.assertEqual(y_solved.get_shape(), tf.TensorShape(expected_shape))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
self.assertEquals(y_solved.shape, expected_shape)
|
||||
|
||||
def test_odeint_all_dtypes(self):
|
||||
func = lambda y, t: y
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
for y0_dtype in [tf.float32, tf.float64, tf.complex64, tf.complex128]:
|
||||
for t_dtype in [tf.float32, tf.float64]:
|
||||
y0 = tf.cast(1.0, y0_dtype)
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, tf.cast(t, t_dtype))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
expected = np.asarray(np.exp(t))
|
||||
self.assertAllClose(y_solved, expected, rtol=1e-5)
|
||||
self.assertEqual(tf.as_dtype(y_solved.dtype), y0_dtype)
|
||||
|
||||
def test_odeint_required_dtypes(self):
|
||||
with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'):
|
||||
tf.contrib.integrate.odeint(self.func, tf.cast(self.y0, tf.int32), [0, 1])
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'):
|
||||
tf.contrib.integrate.odeint(self.func, self.y0, tf.cast([0, 1], tf.int32))
|
||||
|
||||
def test_odeint_runtime_errors(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'cannot supply `options` without'):
|
||||
tf.contrib.integrate.odeint(self.func, self.y0, [0, 1],
|
||||
options={'first_step': 1.0})
|
||||
|
||||
y = tf.contrib.integrate.odeint(self.func, self.y0, [0, 1], method='dopri5',
|
||||
options={'max_num_steps': 0})
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError, 'max_num_steps'):
|
||||
sess.run(y)
|
||||
|
||||
y = tf.contrib.integrate.odeint(self.func, self.y0, [1, 0])
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError, 'monotonic increasing'):
|
||||
sess.run(y)
|
||||
|
||||
def test_odeint_different_times(self):
|
||||
# integrate steps should be independent of interpolation times
|
||||
times0 = np.linspace(0, 10, num=11, dtype=float)
|
||||
times1 = np.linspace(0, 10, num=101, dtype=float)
|
||||
|
||||
with self.test_session() as sess:
|
||||
y_solved_0, info_0 = sess.run(
|
||||
tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, times0, full_output=True))
|
||||
y_solved_1, info_1 = sess.run(
|
||||
tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, times1, full_output=True))
|
||||
|
||||
self.assertAllClose(y_solved_0, y_solved_1[::10])
|
||||
self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals'])
|
||||
self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points'])
|
||||
self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio'])
|
||||
|
||||
def test_odeint_5th_order_accuracy(self):
|
||||
t = [0, 20]
|
||||
kwargs = dict(full_output=True,
|
||||
method='dopri5',
|
||||
options=dict(max_num_steps=2000))
|
||||
with self.test_session() as sess:
|
||||
_, info_0 = sess.run(tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
|
||||
_, info_1 = sess.run(tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs))
|
||||
self.assertAllClose(info_0['integrate_points'].size * 1000 ** 0.2,
|
||||
float(info_1['integrate_points'].size),
|
||||
rtol=0.01)
|
||||
|
||||
|
||||
class StepSizeTest(tf.test.TestCase):
|
||||
|
||||
def test_error_ratio_one(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(1.0))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 0.9)
|
||||
|
||||
def test_ifactor(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(0.0))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 10.0)
|
||||
|
||||
def test_dfactor(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(1e6))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 0.2)
|
||||
|
||||
|
||||
class InterpolationTest(tf.test.TestCase):
|
||||
|
||||
def test_5th_order_polynomial(self):
|
||||
# this should be an exact fit
|
||||
f = lambda x: x ** 4 + x ** 3 - 2 * x ** 2 + 4 * x + 5
|
||||
f_prime = lambda x: 4 * x ** 3 + 3 * x ** 2 - 4 * x + 4
|
||||
coeffs = odes._interp_fit(
|
||||
f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0)
|
||||
times = np.linspace(0, 10, dtype=np.float32)
|
||||
y_fit = tf.pack([odes._interp_evaluate(coeffs, 0.0, 10.0, t)
|
||||
for t in times])
|
||||
y_expected = f(times)
|
||||
with self.test_session() as sess:
|
||||
y_actual = sess.run(y_fit)
|
||||
self.assertAllClose(y_expected, y_actual)
|
||||
|
||||
# attempt interpolation outside bounds
|
||||
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
sess.run(y_invalid)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -258,10 +258,11 @@ def optimize_loss(loss,
|
||||
grad_values = gradient
|
||||
|
||||
if grad_values is not None:
|
||||
var_name = variable.name.replace(":", "_")
|
||||
if "gradients" in summaries:
|
||||
summary.histogram("gradients/" + variable.name, grad_values)
|
||||
summary.histogram("gradients/%s" % var_name, grad_values)
|
||||
if "gradient_norm" in summaries:
|
||||
summary.scalar("gradient_norm/" + variable.name,
|
||||
summary.scalar("gradient_norm/%s" % var_name,
|
||||
clip_ops.global_norm([grad_values]))
|
||||
|
||||
if clip_gradients is not None and "gradient_norm" in summaries:
|
||||
|
@ -291,7 +291,9 @@ py_test(
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:test_ops",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -29,11 +29,11 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import MetricKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import PredictionKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
|
||||
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import composable_model
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.learn.python.learn.utils import export
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -747,13 +748,16 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
Numpy array of predicted classes (or an iterable of predicted classes if
|
||||
as_iterable is True).
|
||||
"""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[head_lib.PredictionKey.CLASSES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.CLASSES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
|
||||
return preds[head_lib.PredictionKey.CLASSES].reshape(-1)
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key].reshape(-1)
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -775,20 +779,22 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
Numpy array of predicted probabilities (or an iterable of predicted
|
||||
probabilities if as_iterable is True).
|
||||
"""
|
||||
key = prediction_key.PredictionKey.PROBABILITIES
|
||||
preds = self._estimator.predict(
|
||||
x=x, input_fn=input_fn,
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[head_lib.PredictionKey.PROBABILITIES],
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
|
||||
return preds[head_lib.PredictionKey.PROBABILITIES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
def _get_predict_ops(self, features):
|
||||
"""See `Estimator` class."""
|
||||
# pylint: disable=protected-access
|
||||
return self._estimator._get_predict_ops(features)[
|
||||
head_lib.PredictionKey.PROBABILITIES]
|
||||
prediction_key.PredictionKey.PROBABILITIES]
|
||||
|
||||
def get_variable_names(self):
|
||||
"""Returns list of all variable names in this model.
|
||||
@ -826,9 +832,9 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
input_fn=input_fn or default_input_fn,
|
||||
input_feature_key=input_feature_key,
|
||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||
signature_fn=(
|
||||
signature_fn or export.classification_signature_fn_with_prob),
|
||||
prediction_key=head_lib.PredictionKey.PROBABILITIES,
|
||||
signature_fn=(signature_fn or
|
||||
export.classification_signature_fn_with_prob),
|
||||
prediction_key=prediction_key.PredictionKey.PROBABILITIES,
|
||||
default_batch_size=default_batch_size,
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
@ -1041,10 +1047,11 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
|
||||
head=head,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn,
|
||||
default_prediction_key=head_lib.PredictionKey.SCORES,
|
||||
default_prediction_key=prediction_key.PredictionKey.SCORES,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
|
||||
def _get_predict_ops(self, features):
|
||||
"""See base class."""
|
||||
return super(DNNLinearCombinedRegressor, self)._get_predict_ops(features)[
|
||||
head_lib.PredictionKey.SCORES]
|
||||
return super(
|
||||
DNNLinearCombinedRegressor,
|
||||
self)._get_predict_ops(features)[prediction_key.PredictionKey.SCORES]
|
||||
|
@ -45,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
|
||||
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config
|
||||
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
|
||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
|
||||
@ -1108,8 +1109,9 @@ class Estimator(BaseEstimator):
|
||||
|
||||
result = _make_metrics_ops(all_metrics, features, labels,
|
||||
model_fn_ops.predictions)
|
||||
if 'loss' not in result:
|
||||
result['loss'] = metrics_lib.streaming_mean(model_fn_ops.loss)
|
||||
if metric_key.MetricKey.LOSS not in result:
|
||||
result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean(
|
||||
model_fn_ops.loss)
|
||||
return result
|
||||
|
||||
def _get_predict_ops(self, features):
|
||||
|
@ -24,6 +24,8 @@ from tensorflow.contrib import losses
|
||||
from tensorflow.contrib import metrics as metrics_lib
|
||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.session_bundle import exporter
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.framework import ops
|
||||
@ -388,17 +390,17 @@ class _RegressionHead(_Head):
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.SCORES] = array_ops.squeeze(
|
||||
predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
|
||||
logits, squeeze_dims=[1])
|
||||
else:
|
||||
predictions[PredictionKey.SCORES] = logits
|
||||
predictions[prediction_key.PredictionKey.SCORES] = logits
|
||||
return predictions
|
||||
|
||||
# pylint: disable=undefined-variable
|
||||
def _create_signature_fn(self):
|
||||
def _regression_signature_fn(examples, unused_features, predictions):
|
||||
if isinstance(predictions, dict):
|
||||
score = predictions[PredictionKey.SCORES]
|
||||
score = predictions[prediction_key.PredictionKey.SCORES]
|
||||
else:
|
||||
score = predictions
|
||||
|
||||
@ -409,11 +411,12 @@ class _RegressionHead(_Head):
|
||||
return _regression_signature_fn
|
||||
|
||||
def _default_metric(self):
|
||||
return {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.SCORES,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.SCORES,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
|
||||
|
||||
class _MultiClassHead(_Head):
|
||||
@ -530,12 +533,16 @@ class _MultiClassHead(_Head):
|
||||
return self._logits_to_prediction(logits)
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {PredictionKey.LOGITS: logits}
|
||||
# pylint: disable=missing-docstring
|
||||
predictions = {prediction_key.PredictionKey.LOGITS: logits}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
|
||||
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
|
||||
logits)
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.PROBABILITIES] = nn.softmax(logits)
|
||||
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
|
||||
predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
|
||||
logits)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
|
||||
logits, 1)
|
||||
|
||||
return predictions
|
||||
|
||||
@ -546,8 +553,9 @@ class _MultiClassHead(_Head):
|
||||
if isinstance(predictions, dict):
|
||||
default_signature = exporter.classification_signature(
|
||||
input_tensor=examples,
|
||||
classes_tensor=predictions[PredictionKey.CLASSES],
|
||||
scores_tensor=predictions[PredictionKey.PROBABILITIES])
|
||||
classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
|
||||
scores_tensor=predictions[
|
||||
prediction_key.PredictionKey.PROBABILITIES])
|
||||
else:
|
||||
default_signature = exporter.classification_signature(
|
||||
input_tensor=examples,
|
||||
@ -558,44 +566,49 @@ class _MultiClassHead(_Head):
|
||||
return _classification_signature_fn
|
||||
|
||||
def _default_metric(self):
|
||||
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
|
||||
# TODO(b/29366811): This currently results in both an "accuracy" and an
|
||||
# "accuracy/threshold_0.500000_mean" metric for binary classification.
|
||||
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
|
||||
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
|
||||
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
||||
PredictionKey.CLASSES, self._label_name,
|
||||
prediction_key.PredictionKey.CLASSES,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
if self.logits_dimension == 1:
|
||||
def _add_binary_metric(metric_key, metric_fn):
|
||||
metrics[_head_prefixed(self._head_name, metric_key)] = (
|
||||
def _add_binary_metric(key, metric_fn):
|
||||
metrics[_head_prefixed(self._head_name, key)] = (
|
||||
metric_spec.MetricSpec(metric_fn,
|
||||
PredictionKey.LOGISTIC,
|
||||
prediction_key.PredictionKey.LOGISTIC,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
_add_binary_metric(MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
|
||||
_add_binary_metric(MetricKey.LABEL_MEAN, _labels_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
|
||||
|
||||
# Also include the streaming mean of the label as an accuracy baseline, as
|
||||
# a reminder to users.
|
||||
_add_binary_metric(MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
|
||||
|
||||
_add_binary_metric(MetricKey.AUC, _streaming_auc)
|
||||
_add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
|
||||
|
||||
for threshold in self._thresholds:
|
||||
_add_binary_metric(MetricKey.ACCURACY_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
|
||||
_accuracy_at_threshold(threshold))
|
||||
# Precision for positive examples.
|
||||
_add_binary_metric(MetricKey.PRECISION_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
|
||||
_streaming_at_threshold(
|
||||
metrics_lib.streaming_precision_at_thresholds,
|
||||
threshold),)
|
||||
# Recall for positive examples.
|
||||
_add_binary_metric(MetricKey.RECALL_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
|
||||
_streaming_at_threshold(
|
||||
metrics_lib.streaming_recall_at_thresholds,
|
||||
threshold))
|
||||
@ -635,21 +648,24 @@ class _BinarySvmHead(_MultiClassHead):
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {}
|
||||
predictions[PredictionKey.LOGITS] = logits
|
||||
predictions[prediction_key.PredictionKey.LOGITS] = logits
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
|
||||
logits, 1)
|
||||
|
||||
return predictions
|
||||
|
||||
def _default_metric(self):
|
||||
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
|
||||
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
|
||||
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
||||
PredictionKey.CLASSES, self._label_name,
|
||||
prediction_key.PredictionKey.CLASSES,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
# TODO(sibyl-vie3Poto): add more metrics relevant for svms.
|
||||
return metrics
|
||||
@ -674,12 +690,14 @@ class _MultiLabelHead(_MultiClassHead):
|
||||
thresholds=thresholds)
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {PredictionKey.LOGITS: logits}
|
||||
predictions = {prediction_key.PredictionKey.LOGITS: logits}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
|
||||
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
|
||||
logits)
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.PROBABILITIES] = math_ops.sigmoid(logits)
|
||||
predictions[PredictionKey.CLASSES] = math_ops.to_int64(
|
||||
predictions[prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid(
|
||||
logits)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64(
|
||||
math_ops.greater(logits, 0))
|
||||
return predictions
|
||||
|
||||
@ -849,23 +867,3 @@ def _streaming_at_threshold(streaming_metrics_fn, threshold):
|
||||
return array_ops.squeeze(precision_tensor), update_op
|
||||
|
||||
return _streaming_metrics
|
||||
|
||||
|
||||
class PredictionKey(object):
|
||||
CLASSES = "classes"
|
||||
PROBABILITIES = "probabilities"
|
||||
LOGITS = "logits"
|
||||
LOGISTIC = "logistic"
|
||||
SCORES = "scores"
|
||||
|
||||
|
||||
class MetricKey(object):
|
||||
LOSS = "loss"
|
||||
AUC = "auc"
|
||||
PREDICTION_MEAN = "labels/prediction_mean"
|
||||
LABEL_MEAN = "labels/actual_label_mean"
|
||||
ACCURACY = "accuracy"
|
||||
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
|
||||
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
|
||||
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
|
||||
RECALL_MEAN = "recall/positive_threshold_%f_mean"
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.learn.python.learn.utils import export
|
||||
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -267,21 +268,18 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
Example:
|
||||
|
||||
```python
|
||||
education = sparse_column_with_hash_bucket(column_name="education",
|
||||
hash_bucket_size=1000)
|
||||
occupation = sparse_column_with_hash_bucket(column_name="occupation",
|
||||
hash_bucket_size=1000)
|
||||
sparse_column_a = sparse_column_with_hash_bucket(...)
|
||||
sparse_column_b = sparse_column_with_hash_bucket(...)
|
||||
|
||||
education_x_occupation = crossed_column(columns=[education, occupation],
|
||||
hash_bucket_size=10000)
|
||||
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
|
||||
|
||||
# Estimator using the default optimizer.
|
||||
estimator = LinearClassifier(
|
||||
feature_columns=[occupation, education_x_occupation])
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
|
||||
|
||||
# Or estimator using the FTRL optimizer with regularization.
|
||||
estimator = LinearClassifier(
|
||||
feature_columns=[occupation, education_x_occupation],
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
|
||||
optimizer=tf.train.FtrlOptimizer(
|
||||
learning_rate=0.1,
|
||||
l1_regularization_strength=0.001
|
||||
@ -289,7 +287,7 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
|
||||
# Or estimator using the SDCAOptimizer.
|
||||
estimator = LinearClassifier(
|
||||
feature_columns=[occupation, education_x_occupation],
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
|
||||
optimizer=tf.contrib.linear_optimizer.SDCAOptimizer(
|
||||
example_id_column='example_id',
|
||||
num_loss_partitions=...,
|
||||
@ -465,13 +463,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
as_iterable=False)
|
||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||
"""Runs inference to determine the predicted class."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[head_lib.PredictionKey.CLASSES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.CLASSES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
|
||||
return preds[head_lib.PredictionKey.CLASSES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -479,14 +480,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
|
||||
as_iterable=True):
|
||||
"""Runs inference to determine the class probability predictions."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[
|
||||
head_lib.PredictionKey.PROBABILITIES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.PROBABILITIES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
|
||||
return preds[head_lib.PredictionKey.PROBABILITIES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
def get_variable_names(self):
|
||||
return self._estimator.get_variable_names()
|
||||
@ -512,9 +515,9 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
input_fn=input_fn or default_input_fn,
|
||||
input_feature_key=input_feature_key,
|
||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||
signature_fn=(
|
||||
signature_fn or export.classification_signature_fn_with_prob),
|
||||
prediction_key=head_lib.PredictionKey.PROBABILITIES,
|
||||
signature_fn=(signature_fn or
|
||||
export.classification_signature_fn_with_prob),
|
||||
prediction_key=prediction_key.PredictionKey.PROBABILITIES,
|
||||
default_batch_size=default_batch_size,
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
@ -561,16 +564,13 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
||||
Example:
|
||||
|
||||
```python
|
||||
education = sparse_column_with_hash_bucket(column_name="education",
|
||||
hash_bucket_size=1000)
|
||||
occupation = sparse_column_with_hash_bucket(column_name="occupation",
|
||||
hash_bucket_size=1000)
|
||||
sparse_column_a = sparse_column_with_hash_bucket(...)
|
||||
sparse_column_b = sparse_column_with_hash_bucket(...)
|
||||
|
||||
education_x_occupation = crossed_column(columns=[education, occupation],
|
||||
hash_bucket_size=10000)
|
||||
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
|
||||
|
||||
estimator = LinearRegressor(
|
||||
feature_columns=[occupation, education_x_occupation])
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
|
||||
|
||||
# Input builders
|
||||
def input_fn_train: # returns x, y
|
||||
@ -731,13 +731,16 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
||||
as_iterable=False)
|
||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||
"""Runs inference to determine the predicted class."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[head_lib.PredictionKey.SCORES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.SCORES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.SCORES)
|
||||
return preds[head_lib.PredictionKey.SCORES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
def get_variable_names(self):
|
||||
return self._estimator.get_variable_names()
|
||||
@ -764,7 +767,7 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
||||
input_feature_key=input_feature_key,
|
||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||
signature_fn=(signature_fn or export.regression_signature_fn),
|
||||
prediction_key=head_lib.PredictionKey.SCORES,
|
||||
prediction_key=prediction_key.PredictionKey.SCORES,
|
||||
default_batch_size=default_batch_size,
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
|
@ -0,0 +1,30 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Enum for metric keys."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class MetricKey(object):
|
||||
LOSS = "loss"
|
||||
AUC = "auc"
|
||||
PREDICTION_MEAN = "labels/prediction_mean"
|
||||
LABEL_MEAN = "labels/actual_label_mean"
|
||||
ACCURACY = "accuracy"
|
||||
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
|
||||
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
|
||||
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
|
||||
RECALL_MEAN = "recall/positive_threshold_%f_mean"
|
@ -0,0 +1,26 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Enum for model prediction keys."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class PredictionKey(object):
|
||||
CLASSES = "classes"
|
||||
PROBABILITIES = "probabilities"
|
||||
LOGITS = "logits"
|
||||
LOGISTIC = "logistic"
|
||||
SCORES = "scores"
|
@ -30,6 +30,7 @@ from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import linear
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
||||
|
||||
|
||||
@ -188,13 +189,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||
as_iterable=False)
|
||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||
"""Runs inference to determine the predicted class."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[head_lib.PredictionKey.CLASSES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.CLASSES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
|
||||
return preds[head_lib.PredictionKey.CLASSES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -202,14 +206,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
|
||||
as_iterable=True):
|
||||
"""Runs inference to determine the class probability predictions."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[
|
||||
head_lib.PredictionKey.PROBABILITIES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.PROBABILITIES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
|
||||
return preds[head_lib.PredictionKey.PROBABILITIES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def get_variable_names(self):
|
||||
|
@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import resources
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
@ -77,7 +78,8 @@ def get_summary_writer(logdir):
|
||||
|
||||
|
||||
def _make_saver(graph, keep_checkpoint_max=5):
|
||||
vars_to_save = graph.get_collection(ops.GraphKeys.VARIABLES)
|
||||
vars_to_save = (graph.get_collection(ops.GraphKeys.VARIABLES) +
|
||||
graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
|
||||
if vars_to_save:
|
||||
return tf_saver.Saver(vars_to_save,
|
||||
sharded=True,
|
||||
@ -846,9 +848,11 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
|
||||
raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)
|
||||
|
||||
graph = contrib_ops.get_graph_from_inputs(output_dict.values())
|
||||
|
||||
with graph.as_default() as g:
|
||||
with tf_session.Session('') as session:
|
||||
session.run(
|
||||
resources.initialize_resources(resources.shared_resources() +
|
||||
resources.local_resources()))
|
||||
if restore_checkpoint_path:
|
||||
_restore_from_checkpoint(session, g, restore_checkpoint_path)
|
||||
else:
|
||||
|
@ -28,6 +28,8 @@ from tensorflow.contrib.learn.python import learn
|
||||
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_ops
|
||||
from tensorflow.python.ops import resources
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
@ -194,6 +196,19 @@ class GraphActionsTest(tf.test.TestCase):
|
||||
pass
|
||||
self.assertTrue(request_stop.called)
|
||||
|
||||
def test_run_feeds_iter_calls_resources_init(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
in0, _, _ = self._build_inference_graph()
|
||||
handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
|
||||
resources.register_resource(
|
||||
handle=handle,
|
||||
create_op=test_ops.resource_create_op(handle),
|
||||
is_initialized_op=test_ops.resource_initialized_op(handle))
|
||||
|
||||
for _ in learn.graph_actions.run_feeds_iter({'in0': in0},
|
||||
feed_dicts=[{}]):
|
||||
self.assertTrue(test_ops.resource_initialized_op(handle).eval())
|
||||
|
||||
def test_infer_different_default_graph(self):
|
||||
with self.test_session():
|
||||
self._assert_ckpt(self._output_dir, False)
|
||||
|
@ -1,7 +1,8 @@
|
||||
### TensorFlow Makefile
|
||||
|
||||
The recommended way to build TensorFlow from source is using the Bazel
|
||||
open-source build system. Sometimes this isn't possible.
|
||||
open-source build system. Sometimes this isn't possible. For example,
|
||||
if you are building for iOS, you currently need to use the Makefile.
|
||||
|
||||
- The build system may not have the RAM or processing power to support Bazel.
|
||||
- Bazel or its dependencies may not be available.
|
||||
|
@ -43,6 +43,13 @@ tensorflow/core/kernels/sequence_ops.cc
|
||||
tensorflow/core/kernels/sendrecv_ops.cc
|
||||
tensorflow/core/kernels/scatter_op.cc
|
||||
tensorflow/core/kernels/scatter_functor.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
|
||||
tensorflow/core/kernels/scatter_nd_op.cc
|
||||
tensorflow/core/kernels/save_restore_tensor.cc
|
||||
tensorflow/core/kernels/save_restore_v2_ops.cc
|
||||
tensorflow/core/kernels/save_op.cc
|
||||
|
@ -763,7 +763,12 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
|
||||
computes the area under a discretized curve of precision versus recall values
|
||||
(computed using the aforementioned variables). The `num_thresholds` variable
|
||||
controls the degree of discretization with larger numbers of thresholds more
|
||||
closely approximating the true AUC.
|
||||
closely approximating the true AUC. The quality of the approximation may vary
|
||||
dramatically depending on `num_thresholds`.
|
||||
|
||||
For best results, `predictions` should be distributed approximately uniformly
|
||||
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
|
||||
approximation may be poor if this is not the case.
|
||||
|
||||
For estimation of the metric over a stream of data, the function creates an
|
||||
`update_op` operation that updates these variables and returns the `auc`.
|
||||
|
@ -15,9 +15,16 @@ py_library(
|
||||
"python/training/resample.py",
|
||||
"python/training/sampling_ops.py",
|
||||
"python/training/sequence_queueing_state_saver.py",
|
||||
"python/training/training.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
@ -98,6 +105,19 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "training_test",
|
||||
size = "large",
|
||||
srcs = ["python/training/training_test.py"],
|
||||
shard_count = 3,
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -70,6 +70,11 @@ from tensorflow.contrib.training.python.training.bucket_ops import *
|
||||
from tensorflow.contrib.training.python.training.resample import *
|
||||
from tensorflow.contrib.training.python.training.sampling_ops import *
|
||||
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
||||
from tensorflow.contrib.training.python.training.training import add_gradients_summaries
|
||||
from tensorflow.contrib.training.python.training.training import clip_gradient_norms
|
||||
from tensorflow.contrib.training.python.training.training import create_train_op
|
||||
from tensorflow.contrib.training.python.training.training import multiply_gradients
|
||||
from tensorflow.contrib.training.python.training.training import train
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
316
tensorflow/contrib/training/python/training/training.py
Normal file
316
tensorflow/contrib/training/python/training/training.py
Normal file
@ -0,0 +1,316 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains various routines and helper functions for training models.
|
||||
|
||||
TODO(nsilberman): Port documentation.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import optimizer as tf_optimizer
|
||||
|
||||
# TODO(nsilberman): move add_gradients_summaries, clip_gradient_norms and
|
||||
# multiply_gradients into contrib/summaries and contrib/optimizers.py
|
||||
__all__ = [
|
||||
'add_gradients_summaries',
|
||||
'clip_gradient_norms',
|
||||
'create_train_op',
|
||||
'multiply_gradients',
|
||||
'train',
|
||||
]
|
||||
|
||||
|
||||
def add_gradients_summaries(grads_and_vars):
|
||||
"""Add summaries to gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||
|
||||
Returns:
|
||||
The list of created summaries.
|
||||
"""
|
||||
summaries = []
|
||||
for grad, var in grads_and_vars:
|
||||
if grad is not None:
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
grad_values = grad.values
|
||||
else:
|
||||
grad_values = grad
|
||||
summaries.append(summary.histogram_summary(
|
||||
var.op.name + ':gradient', grad_values))
|
||||
summaries.append(summary.histogram_summary(
|
||||
var.op.name + ':gradient_norm', clip_ops.global_norm([grad_values])))
|
||||
else:
|
||||
logging.info('Var %s has no gradient', var.op.name)
|
||||
|
||||
return summaries
|
||||
|
||||
|
||||
def clip_gradient_norms(gradients_to_variables, max_norm):
|
||||
"""Clips the gradients by the given value.
|
||||
|
||||
Args:
|
||||
gradients_to_variables: A list of gradient to variable pairs (tuples).
|
||||
max_norm: the maximum norm value.
|
||||
|
||||
Returns:
|
||||
A list of clipped gradient to variable pairs.
|
||||
"""
|
||||
clipped_grads_and_vars = []
|
||||
for grad, var in gradients_to_variables:
|
||||
if grad is not None:
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
tmp = clip_ops.clip_by_norm(grad.values, max_norm)
|
||||
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||
else:
|
||||
grad = clip_ops.clip_by_norm(grad, max_norm)
|
||||
clipped_grads_and_vars.append((grad, var))
|
||||
return clipped_grads_and_vars
|
||||
|
||||
|
||||
def multiply_gradients(grads_and_vars, gradient_multipliers):
|
||||
"""Multiply specified gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||
gradient_multipliers: A map from either `Variables` or `Variable` op names
|
||||
to the coefficient by which the associated gradient should be scaled.
|
||||
|
||||
Returns:
|
||||
The updated list of gradient to variable pairs.
|
||||
|
||||
Raises:
|
||||
ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
|
||||
is empty or None or if `gradient_multipliers` is not a dictionary.
|
||||
"""
|
||||
if not isinstance(grads_and_vars, list):
|
||||
raise ValueError('`grads_and_vars` must be a list.')
|
||||
if not gradient_multipliers:
|
||||
raise ValueError('`gradient_multipliers` is empty.')
|
||||
if not isinstance(gradient_multipliers, dict):
|
||||
raise ValueError('`gradient_multipliers` must be a dict.')
|
||||
|
||||
multiplied_grads_and_vars = []
|
||||
for grad, var in grads_and_vars:
|
||||
if var in gradient_multipliers or var.op.name in gradient_multipliers:
|
||||
key = var if var in gradient_multipliers else var.op.name
|
||||
if grad is None:
|
||||
raise ValueError('Requested multiple of `None` gradient.')
|
||||
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
tmp = grad.values * constant_op.constant(
|
||||
gradient_multipliers[key], dtype=grad.dtype)
|
||||
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||
else:
|
||||
grad *= constant_op.constant(
|
||||
gradient_multipliers[key], dtype=grad.dtype)
|
||||
multiplied_grads_and_vars.append((grad, var))
|
||||
return multiplied_grads_and_vars
|
||||
|
||||
|
||||
def create_train_op(total_loss,
|
||||
optimizer,
|
||||
global_step=None,
|
||||
update_ops=None,
|
||||
variables_to_train=None,
|
||||
transform_grads_fn=None,
|
||||
summarize_gradients=False,
|
||||
gate_gradients=tf_optimizer.Optimizer.GATE_OP,
|
||||
aggregation_method=None,
|
||||
colocate_gradients_with_ops=False):
|
||||
"""Creates an `Operation` that evaluates the gradients and returns the loss.
|
||||
|
||||
Args:
|
||||
total_loss: A `Tensor` representing the total loss.
|
||||
optimizer: A tf.Optimizer to use for computing the gradients.
|
||||
global_step: A `Tensor` representing the global step variable. If left as
|
||||
`None`, then slim.variables.global_step() is used.
|
||||
update_ops: An optional list of updates to execute. If `update_ops` is
|
||||
`None`, then the update ops are set to the contents of the
|
||||
`tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
|
||||
it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
|
||||
a warning will be displayed.
|
||||
variables_to_train: an optional list of variables to train. If None, it will
|
||||
default to all tf.trainable_variables().
|
||||
transform_grads_fn: A function which takes a single argument, a list of
|
||||
gradient to variable pairs (tuples), performs any requested gradient
|
||||
updates, such as gradient clipping or multipliers, and returns the updated
|
||||
list.
|
||||
summarize_gradients: Whether or not add summaries for each gradient.
|
||||
gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
|
||||
aggregation_method: Specifies the method used to combine gradient terms.
|
||||
Valid values are defined in the class `AggregationMethod`.
|
||||
colocate_gradients_with_ops: Whether or not to try colocating the gradients
|
||||
with the ops that generated them.
|
||||
|
||||
Returns:
|
||||
A `Tensor` that when evaluated, computes the gradients and returns the total
|
||||
loss value.
|
||||
"""
|
||||
if global_step is None:
|
||||
global_step = variables.get_or_create_global_step()
|
||||
|
||||
# Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
|
||||
global_update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
|
||||
if update_ops is None:
|
||||
update_ops = global_update_ops
|
||||
else:
|
||||
update_ops = set(update_ops)
|
||||
if not global_update_ops.issubset(update_ops):
|
||||
logging.warning('update_ops in create_train_op does not contain all the '
|
||||
' update_ops in GraphKeys.UPDATE_OPS')
|
||||
|
||||
# Make sure update_ops are computed before total_loss.
|
||||
if update_ops:
|
||||
with ops.control_dependencies(update_ops):
|
||||
barrier = control_flow_ops.no_op(name='update_barrier')
|
||||
total_loss = control_flow_ops.with_dependencies([barrier], total_loss)
|
||||
|
||||
if variables_to_train is None:
|
||||
# Default to tf.trainable_variables()
|
||||
variables_to_train = tf_variables.trainable_variables()
|
||||
else:
|
||||
# Make sure that variables_to_train are in tf.trainable_variables()
|
||||
for v in variables_to_train:
|
||||
assert v in tf_variables.trainable_variables()
|
||||
|
||||
assert variables_to_train
|
||||
|
||||
# Create the gradients. Note that apply_gradients adds the gradient
|
||||
# computation to the current graph.
|
||||
grads = optimizer.compute_gradients(
|
||||
total_loss,
|
||||
variables_to_train,
|
||||
gate_gradients=gate_gradients,
|
||||
aggregation_method=aggregation_method,
|
||||
colocate_gradients_with_ops=colocate_gradients_with_ops)
|
||||
|
||||
if transform_grads_fn:
|
||||
grads = transform_grads_fn(grads)
|
||||
|
||||
# Summarize gradients.
|
||||
if summarize_gradients:
|
||||
with ops.name_scope('summarize_grads'):
|
||||
add_gradients_summaries(grads)
|
||||
|
||||
# Create gradient updates.
|
||||
grad_updates = optimizer.apply_gradients(grads, global_step=global_step)
|
||||
|
||||
with ops.name_scope('train_op'):
|
||||
# Make sure total_loss is valid.
|
||||
total_loss = array_ops.check_numerics(total_loss,
|
||||
'LossTensor is inf or nan')
|
||||
|
||||
# Ensure the train_tensor computes grad_updates.
|
||||
return control_flow_ops.with_dependencies([grad_updates], total_loss)
|
||||
|
||||
|
||||
def train(
|
||||
train_op,
|
||||
logdir,
|
||||
master='',
|
||||
is_chief=True,
|
||||
scaffold=None,
|
||||
hooks=None,
|
||||
chief_only_hooks=None,
|
||||
save_checkpoint_secs=600,
|
||||
save_summaries_steps=100,
|
||||
config=None):
|
||||
"""Runs the training loop.
|
||||
|
||||
Args:
|
||||
train_op: A `Tensor` that, when executed, will apply the gradients and
|
||||
return the loss value.
|
||||
logdir: The directory where the graph and checkpoints are saved.
|
||||
master: The URL of the master.
|
||||
is_chief: Specifies whether or not the training is being run by the primary
|
||||
replica during replica training.
|
||||
scaffold: An tf.train.Scaffold instance.
|
||||
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
|
||||
training loop.
|
||||
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
|
||||
inside the training loop for the chief trainer only.
|
||||
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
|
||||
using a default checkpoint saver. If `save_checkpoint_secs` is set to
|
||||
`None`, then the default checkpoint saver isn't used.
|
||||
save_summaries_steps: The frequency, in number of global steps, that the
|
||||
summaries are written to disk using a default summary saver. If
|
||||
`save_summaries_steps` is set to `None`, then the default summary saver
|
||||
isn't used.
|
||||
config: An instance of `tf.ConfigProto`.
|
||||
|
||||
Returns:
|
||||
the value of the loss function after training.
|
||||
|
||||
Raises:
|
||||
ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
|
||||
`save_summaries_steps` are `None.
|
||||
"""
|
||||
# TODO(nsilberman): move this logic into monitored_session.py
|
||||
scaffold = scaffold or monitored_session.Scaffold()
|
||||
|
||||
hooks = hooks or []
|
||||
|
||||
if is_chief:
|
||||
session_creator = monitored_session.ChiefSessionCreator(
|
||||
scaffold=scaffold,
|
||||
checkpoint_dir=logdir,
|
||||
master=master,
|
||||
config=config)
|
||||
|
||||
if chief_only_hooks:
|
||||
hooks.extend(chief_only_hooks)
|
||||
|
||||
hooks.append(basic_session_run_hooks.StepCounterHook(
|
||||
output_dir=logdir))
|
||||
|
||||
if save_summaries_steps:
|
||||
if logdir is None:
|
||||
raise ValueError(
|
||||
'logdir cannot be None when save_summaries_steps is None')
|
||||
hooks.append(basic_session_run_hooks.SummarySaverHook(
|
||||
scaffold=scaffold,
|
||||
save_steps=save_summaries_steps,
|
||||
output_dir=logdir))
|
||||
|
||||
if save_checkpoint_secs:
|
||||
if logdir is None:
|
||||
raise ValueError(
|
||||
'logdir cannot be None when save_checkpoint_secs is None')
|
||||
hooks.append(basic_session_run_hooks.CheckpointSaverHook(
|
||||
logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
|
||||
else:
|
||||
session_creator = monitored_session.WorkerSessionCreator(
|
||||
scaffold=scaffold, master=master, config=config)
|
||||
|
||||
with monitored_session.MonitoredSession(
|
||||
session_creator=session_creator, hooks=hooks) as session:
|
||||
loss = None
|
||||
while not session.should_stop():
|
||||
loss = session.run(train_op)
|
||||
return loss
|
514
tensorflow/contrib/training/python/training/training_test.py
Normal file
514
tensorflow/contrib/training/python/training/training_test.py
Normal file
@ -0,0 +1,514 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.training.training."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def logistic_classifier(inputs):
|
||||
return tf.contrib.layers.fully_connected(
|
||||
inputs, 1, activation_fn=tf.sigmoid)
|
||||
|
||||
|
||||
def batchnorm_classifier(inputs):
|
||||
inputs = tf.contrib.layers.batch_norm(inputs, decay=0.1)
|
||||
return tf.contrib.layers.fully_connected(inputs, 1, activation_fn=tf.sigmoid)
|
||||
|
||||
|
||||
class CreateTrainOpTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
np.random.seed(0)
|
||||
|
||||
# Create an easy training set:
|
||||
self._inputs = np.random.rand(16, 4).astype(np.float32)
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
|
||||
def testUseUpdateOps(self):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
expected_mean = np.mean(self._inputs, axis=(0))
|
||||
expected_var = np.var(self._inputs, axis=(0))
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
|
||||
moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
|
||||
moving_variance = tf.contrib.framework.get_variables_by_name(
|
||||
'moving_variance')[0]
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize all variables
|
||||
sess.run(tf.initialize_all_variables())
|
||||
mean, variance = sess.run([moving_mean, moving_variance])
|
||||
# After initialization moving_mean == 0 and moving_variance == 1.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
for _ in range(10):
|
||||
sess.run([train_op])
|
||||
mean = moving_mean.eval()
|
||||
variance = moving_variance.eval()
|
||||
# After 10 updates with decay 0.1 moving_mean == expected_mean and
|
||||
# moving_variance == expected_var.
|
||||
self.assertAllClose(mean, expected_mean)
|
||||
self.assertAllClose(variance, expected_var)
|
||||
|
||||
def testEmptyUpdateOps(self):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, update_ops=[])
|
||||
|
||||
moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
|
||||
moving_variance = tf.contrib.framework.get_variables_by_name(
|
||||
'moving_variance')[0]
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize all variables
|
||||
sess.run(tf.initialize_all_variables())
|
||||
mean, variance = sess.run([moving_mean, moving_variance])
|
||||
# After initialization moving_mean == 0 and moving_variance == 1.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
for _ in range(10):
|
||||
sess.run([train_op])
|
||||
mean = moving_mean.eval()
|
||||
variance = moving_variance.eval()
|
||||
|
||||
# Since we skip update_ops the moving_vars are not updated.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
|
||||
class TrainBNClassifierTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
self._logdir = os.path.join(self.get_temp_dir(), 'tmp_bnlogs/')
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, self._logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertLess(loss, .1)
|
||||
|
||||
|
||||
class TrainTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testCanAchieveZeroLoss(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')
|
||||
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testTrainWithLocalVariable(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'train_with_local_variable')
|
||||
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
local_multiplier = tf.contrib.framework.local_variable(1.0)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testResumeTrainAchievesRoughlyTheSameLoss(self):
|
||||
number_of_steps = [300, 1, 5]
|
||||
logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')
|
||||
|
||||
for i in range(len(number_of_steps)):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(i)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps[i]),
|
||||
tf.train.CheckpointSaverHook(
|
||||
logdir, save_steps=50, saver=saver),
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(
|
||||
learning_rate=learning_rate)
|
||||
|
||||
def transform_grads_fn(grads):
|
||||
if gradient_multiplier != 1.0:
|
||||
variables = tf.trainable_variables()
|
||||
gradient_multipliers = {var: gradient_multiplier for var in variables}
|
||||
|
||||
with tf.name_scope('multiply_grads'):
|
||||
return tf.contrib.training.multiply_gradients(
|
||||
grads, gradient_multipliers)
|
||||
else:
|
||||
return grads
|
||||
|
||||
return tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, transform_grads_fn=transform_grads_fn)
|
||||
|
||||
def testTrainWithInitFromCheckpoint(self):
|
||||
logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
|
||||
logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
|
||||
|
||||
if tf.gfile.Exists(logdir1): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir1)
|
||||
if tf.gfile.Exists(logdir2): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir2)
|
||||
|
||||
# First, train the model one step (make sure the error is high).
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op()
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=1),
|
||||
], save_checkpoint_secs=None)
|
||||
self.assertGreater(loss, .5)
|
||||
|
||||
# Next, train the model to convergence.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(1)
|
||||
train_op = self.create_train_op()
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=300),
|
||||
], save_checkpoint_secs=None)
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .02)
|
||||
|
||||
# Finally, advance the model a single step and validate that the loss is
|
||||
# still low.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(2)
|
||||
train_op = self.create_train_op()
|
||||
|
||||
model_variables = tf.all_variables()
|
||||
model_path = os.path.join(logdir1, 'model.ckpt-300')
|
||||
|
||||
assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
|
||||
model_path, model_variables)
|
||||
def init_fn(_, session):
|
||||
assign_fn(session)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op,
|
||||
logdir2,
|
||||
scaffold=tf.train.Scaffold(init_fn=init_fn),
|
||||
hooks=[tf.train.StopAtStepHook(num_steps=1)])
|
||||
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .02)
|
||||
|
||||
def ModelLoss(self):
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
return tf.contrib.losses.get_total_loss()
|
||||
|
||||
def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
|
||||
if tf.gfile.Exists(logdir): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir)
|
||||
|
||||
# First, train only the weights of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
weights = tf.contrib.framework.get_variables_by_name('weights')
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss,
|
||||
optimizer,
|
||||
variables_to_train=weights)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=200),
|
||||
])
|
||||
self.assertGreater(loss, .015)
|
||||
self.assertLess(loss, .05)
|
||||
|
||||
# Next, train the biases of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(1)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
biases = tf.contrib.framework.get_variables_by_name('biases')
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss,
|
||||
optimizer,
|
||||
variables_to_train=biases)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=300),
|
||||
])
|
||||
self.assertGreater(loss, .015)
|
||||
self.assertLess(loss, .05)
|
||||
|
||||
# Finally, train both weights and bias to get lower loss.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(2)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=400),
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
|
||||
# First, train only the weights of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
weights, biases = tf.contrib.framework.get_variables()
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
train_weights = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, variables_to_train=[weights])
|
||||
train_biases = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, variables_to_train=[biases])
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
|
||||
# Get the intial weights and biases values.
|
||||
weights_values, biases_values = sess.run([weights, biases])
|
||||
self.assertGreater(np.linalg.norm(weights_values), 0)
|
||||
self.assertAlmostEqual(np.linalg.norm(biases_values), 0)
|
||||
|
||||
# Update weights and biases.
|
||||
loss = sess.run(train_op)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the weights and biases have been updated.
|
||||
self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
|
||||
|
||||
weights_values, biases_values = new_weights, new_biases
|
||||
|
||||
# Update only weights.
|
||||
loss = sess.run(train_weights)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the weights have been updated, but biases have not.
|
||||
self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
|
||||
weights_values = new_weights
|
||||
|
||||
# Update only biases.
|
||||
loss = sess.run(train_biases)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the biases have been updated, but weights have not.
|
||||
self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
|
||||
|
||||
def testTrainWithAlteredGradients(self):
|
||||
# Use the same learning rate but different gradient multipliers
|
||||
# to train two models. Model with equivalently larger learning
|
||||
# rate (i.e., learning_rate * gradient_multiplier) has smaller
|
||||
# training loss.
|
||||
logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
|
||||
logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')
|
||||
|
||||
if tf.gfile.Exists(logdir1):
|
||||
tf.gfile.DeleteRecursively(logdir1)
|
||||
if tf.gfile.Exists(logdir2):
|
||||
tf.gfile.DeleteRecursively(logdir2)
|
||||
|
||||
multipliers = [1., 1000.]
|
||||
number_of_steps = 10
|
||||
losses = []
|
||||
learning_rate = 0.001
|
||||
|
||||
# First, train the model with equivalently smaller learning rate.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op(
|
||||
learning_rate=learning_rate,
|
||||
gradient_multiplier=multipliers[0])
|
||||
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps),
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=50, saver=saver),
|
||||
])
|
||||
|
||||
losses.append(loss)
|
||||
self.assertGreater(loss, .5)
|
||||
|
||||
# Second, train the model with equivalently larger learning rate.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op(
|
||||
learning_rate=learning_rate,
|
||||
gradient_multiplier=multipliers[1])
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir2, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps),
|
||||
tf.train.CheckpointSaverHook(logdir2, save_steps=50, saver=saver),
|
||||
])
|
||||
|
||||
losses.append(loss)
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .5)
|
||||
|
||||
# The loss of the model trained with larger learning rate should
|
||||
# be smaller.
|
||||
self.assertGreater(losses[0], losses[1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -221,13 +221,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jpeg",
|
||||
hdrs = ["lib/jpeg/jpeg_mem.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":jpeg_internal"],
|
||||
)
|
||||
|
||||
# Test support library needed for all tests
|
||||
# This is currently public, but may be made internal in the
|
||||
# future. Try to avoid depending on it.
|
||||
@ -699,9 +692,9 @@ filegroup(
|
||||
"platform/cuda.h",
|
||||
"platform/google/**/*",
|
||||
"platform/hadoop/**/*",
|
||||
"platform/jpeg.*",
|
||||
"platform/png.*",
|
||||
"platform/gif.*",
|
||||
"platform/gif.h",
|
||||
"platform/jpeg.h",
|
||||
"platform/png.h",
|
||||
"platform/stream_executor.*",
|
||||
"platform/windows/**/*",
|
||||
"user_ops/**/*.cu.cc",
|
||||
@ -981,7 +974,10 @@ cc_library(
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"lib/gif/**/*",
|
||||
"lib/jpeg/**/*",
|
||||
"platform/gif.h",
|
||||
"platform/jpeg.h",
|
||||
"platform/**/cuda.h",
|
||||
"platform/**/stream_executor.h",
|
||||
"platform/load_library.cc",
|
||||
@ -998,7 +994,10 @@ cc_library(
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"lib/gif/**/*",
|
||||
"lib/jpeg/**/*",
|
||||
"platform/gif.h",
|
||||
"platform/jpeg.h",
|
||||
"platform/**/cuda.h",
|
||||
"platform/**/stream_executor.h",
|
||||
],
|
||||
@ -1016,7 +1015,6 @@ cc_library(
|
||||
hdrs = tf_additional_lib_hdrs() + [
|
||||
"lib/core/blocking_counter.h",
|
||||
"lib/core/refcount.h",
|
||||
"lib/gif/gif_io.h",
|
||||
"lib/gtl/edit_distance.h",
|
||||
"lib/gtl/int_type.h",
|
||||
"lib/gtl/iterator_range.h",
|
||||
@ -1060,18 +1058,32 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gif_internal",
|
||||
srcs = [
|
||||
"lib/gif/gif_io.cc",
|
||||
"platform/gif.h",
|
||||
],
|
||||
hdrs = ["lib/gif/gif_io.h"],
|
||||
copts = tf_copts(),
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
":lib",
|
||||
"//tensorflow/core/platform/default/build_config:gif",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jpeg_internal",
|
||||
srcs = glob(
|
||||
[
|
||||
"lib/jpeg/*h",
|
||||
"lib/jpeg/*.cc",
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
],
|
||||
),
|
||||
hdrs = ["lib/jpeg/jpeg_handle.h"],
|
||||
srcs = [
|
||||
"lib/jpeg/jpeg_handle.cc",
|
||||
"lib/jpeg/jpeg_mem.cc",
|
||||
"platform/jpeg.h",
|
||||
],
|
||||
hdrs = [
|
||||
"lib/jpeg/jpeg_handle.h",
|
||||
"lib/jpeg/jpeg_mem.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
@ -1541,7 +1553,6 @@ cc_test(
|
||||
srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
|
||||
data = glob(["lib/jpeg/testdata/*.jpg"]),
|
||||
deps = [
|
||||
":jpeg",
|
||||
":jpeg_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
|
@ -78,7 +78,7 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
|
||||
Status DeviceFactory::AddDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices) {
|
||||
// CPU first.
|
||||
// CPU first. A CPU device is required.
|
||||
auto cpu_factory = GetFactory("CPU");
|
||||
if (!cpu_factory) {
|
||||
return errors::NotFound(
|
||||
@ -90,18 +90,11 @@ Status DeviceFactory::AddDevices(const SessionOptions& options,
|
||||
return errors::NotFound("No CPU devices are available in this process");
|
||||
}
|
||||
|
||||
// Then GPU.
|
||||
auto gpu_factory = GetFactory("GPU");
|
||||
if (gpu_factory) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
gpu_factory->CreateDevices(options, name_prefix, devices));
|
||||
}
|
||||
|
||||
// Then the rest.
|
||||
// Then the rest (including GPU).
|
||||
mutex_lock l(*get_device_factory_lock());
|
||||
for (auto& p : device_factories()) {
|
||||
auto factory = p.second.factory.get();
|
||||
if (factory != cpu_factory && factory != gpu_factory) {
|
||||
if (factory != cpu_factory) {
|
||||
TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
|
||||
}
|
||||
}
|
||||
|
@ -282,6 +282,7 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
|
||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||
return;
|
||||
}
|
||||
mu_.unlock();
|
||||
|
||||
SchedClosure([session, req, resp, done]() {
|
||||
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
|
||||
@ -290,7 +291,22 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
|
||||
}
|
||||
done(status);
|
||||
});
|
||||
}
|
||||
|
||||
void Master::PartialRunSetup(const PartialRunSetupRequest* req,
|
||||
PartialRunSetupResponse* resp, MyClosure done) {
|
||||
mu_.lock();
|
||||
MasterSession* session = gtl::FindPtrOrNull(sessions_, req->session_handle());
|
||||
if (session == nullptr) {
|
||||
mu_.unlock();
|
||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||
return;
|
||||
}
|
||||
mu_.unlock();
|
||||
|
||||
SchedClosure([this, session, req, resp, done]() {
|
||||
done(session->PartialRunSetup(req, resp));
|
||||
});
|
||||
}
|
||||
|
||||
void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
|
||||
@ -303,6 +319,7 @@ void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
|
||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||
return;
|
||||
}
|
||||
mu_.unlock();
|
||||
|
||||
SchedClosure([this, start_time, session, opts, req, resp, done]() {
|
||||
Status status = session->Run(opts, req, resp);
|
||||
@ -312,7 +329,6 @@ void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
|
||||
last_1000_steps_.AddValue((done_time - start_time) / 1e9);
|
||||
++step_count_;
|
||||
});
|
||||
mu_.unlock();
|
||||
}
|
||||
|
||||
void Master::CloseSession(const CloseSessionRequest* req,
|
||||
|
@ -46,6 +46,9 @@ class Master {
|
||||
void ExtendSession(const ExtendSessionRequest* req,
|
||||
ExtendSessionResponse* resp, MyClosure done);
|
||||
|
||||
void PartialRunSetup(const PartialRunSetupRequest* req,
|
||||
PartialRunSetupResponse* resp, MyClosure done);
|
||||
|
||||
void RunStep(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp, MyClosure done);
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/master.pb.h"
|
||||
|
||||
@ -37,6 +38,12 @@ class MasterInterface {
|
||||
const ExtendSessionRequest* request,
|
||||
ExtendSessionResponse* response) = 0;
|
||||
|
||||
virtual Status PartialRunSetup(CallOptions* call_options,
|
||||
const PartialRunSetupRequest* request,
|
||||
PartialRunSetupResponse* response) {
|
||||
return errors::Unimplemented("Partial run not implemented for this master");
|
||||
}
|
||||
|
||||
virtual Status RunStep(CallOptions* call_options,
|
||||
const RunStepRequest* request,
|
||||
RunStepResponse* response) = 0;
|
||||
|
@ -50,18 +50,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A little bit of per-step state.
|
||||
struct PerStepState {
|
||||
bool collect_costs = false;
|
||||
bool collect_timeline = false;
|
||||
bool collect_rpcs = false;
|
||||
Microseconds start_micros = Microseconds(0);
|
||||
Microseconds end_micros = Microseconds(0);
|
||||
std::vector<StepStats> step_stats; // per partition
|
||||
StepStats rpc_stats; // for RPC layer
|
||||
CostGraphDef cost_graph;
|
||||
};
|
||||
|
||||
// MasterSession wraps SimpleClientGraph in a reference counted object.
|
||||
// This way, MasterSession can clear up the cache mapping Run requests to
|
||||
// compiled graphs while the compiled graph is still being used.
|
||||
@ -72,15 +60,38 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
|
||||
std::unique_ptr<SimpleClientGraph> cg,
|
||||
const SessionOptions& session_opts,
|
||||
StatsPublisherFactory stats_publisher_factory)
|
||||
StatsPublisherFactory stats_publisher_factory,
|
||||
SimpleGraphExecutionState* execution_state, bool is_partial)
|
||||
: session_handle_(handle),
|
||||
client_graph_(std::move(cg)),
|
||||
bopts_(bopts),
|
||||
session_opts_(session_opts) {
|
||||
session_opts_(session_opts),
|
||||
is_partial_(is_partial) {
|
||||
VLOG(1) << "Created ReffedClientGraph for node with "
|
||||
<< client_graph_->graph.num_node_ids();
|
||||
|
||||
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
|
||||
|
||||
// If this is a partial run we need to initialize a name to node map for
|
||||
// testing that fetches are reachable.
|
||||
if (is_partial) {
|
||||
std::unordered_set<StringPiece, StringPiece::Hasher> names;
|
||||
for (const string& input : bopts.feed_endpoints) {
|
||||
TensorId id(ParseTensorName(input));
|
||||
names.emplace(id.first);
|
||||
}
|
||||
for (const string& output : bopts.fetch_endpoints) {
|
||||
TensorId id(ParseTensorName(output));
|
||||
names.emplace(id.first);
|
||||
}
|
||||
// We use the graph from the execution_state because we want the graph
|
||||
// nodes before they are rewritten replaced by the rewriter.
|
||||
for (Node* n : execution_state->full_graph()->nodes()) {
|
||||
if (names.count(n->name()) > 0) {
|
||||
name_to_node_.insert({n->name(), n});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~ReffedClientGraph() override { DeregisterPartitions(); }
|
||||
@ -171,7 +182,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
SimpleGraphExecutionState* execution_state,
|
||||
PerStepState* pss, CallOptions* opts,
|
||||
const RunStepRequest& req, RunStepResponse* resp,
|
||||
CancellationManager* cm);
|
||||
CancellationManager* cm, const bool is_last_partial_run);
|
||||
|
||||
// Calls workers to cleanup states for the step "step_id". Calls
|
||||
// `done` when all cleanup RPCs have completed.
|
||||
@ -185,6 +196,9 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
void ProcessDeviceStats(ProfileHandler* ph,
|
||||
const SimpleGraphExecutionState* execution_state,
|
||||
const DeviceStepStats& ds, bool is_rpc);
|
||||
// Checks that the requested fetches can be computed from the provided feeds.
|
||||
Status CheckFetches(const RunStepRequest& req, const RunState* run_state,
|
||||
SimpleGraphExecutionState* execution_state);
|
||||
|
||||
string DetailText(const NodeDef& def, const NodeExecStats& ns) {
|
||||
int64 tot = 0;
|
||||
@ -209,6 +223,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
std::unordered_set<const Node*> nodes_needing_input_mapping_;
|
||||
BuildGraphOptions bopts_;
|
||||
const SessionOptions session_opts_;
|
||||
const bool is_partial_;
|
||||
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> name_to_node_;
|
||||
|
||||
// Graph partitioned into per-location subgraphs.
|
||||
struct Part {
|
||||
@ -483,15 +499,14 @@ class RunManyGraphs {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
|
||||
};
|
||||
|
||||
|
||||
Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
const MasterEnv* env, int64 step_id, int64 execution_count,
|
||||
SimpleGraphExecutionState* execution_state, PerStepState* pss,
|
||||
CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp,
|
||||
CancellationManager* cm) {
|
||||
CancellationManager* cm, const bool is_last_partial_run) {
|
||||
VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
|
||||
<< execution_count;
|
||||
// Builds an index for feeds provided by the client.
|
||||
// Build an index for feeds provided by the client.
|
||||
std::unordered_map<StringPiece, const TensorProto*, StringPiece::Hasher>
|
||||
feeds(3);
|
||||
|
||||
@ -524,26 +539,64 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
for (int i = 0; i < num; ++i) {
|
||||
const Part& part = partitions_[i];
|
||||
RunManyGraphs::Call* c = calls.get(i);
|
||||
if (is_partial_) {
|
||||
c->req.set_is_partial(is_partial_);
|
||||
c->req.set_is_last_partial_run(is_last_partial_run);
|
||||
}
|
||||
c->req.set_graph_handle(part.graph_handle);
|
||||
c->req.set_step_id(step_id);
|
||||
*c->req.mutable_exec_opts() = exec_opts;
|
||||
// If any feeds are provided, send the feed values together
|
||||
// in the RunGraph request.
|
||||
for (const auto& feed_key : part.feed_key) {
|
||||
const string& feed = feed_key.first;
|
||||
const string& key = feed_key.second;
|
||||
const TensorProto* val = feeds[feed];
|
||||
if (val == nullptr) {
|
||||
return errors::InvalidArgument("No feed is provided for feed=", feed,
|
||||
", key=", key);
|
||||
// In the partial case, we only want to include feeds provided in the req.
|
||||
// In the non-partial case, all feeds in the request are in the part.
|
||||
// We keep these as separate paths for now, to ensure we aren't
|
||||
// inadvertently slowing down the normal run path.
|
||||
if (is_partial_) {
|
||||
for (const auto& feed : req.feed()) {
|
||||
const string& name = feed.name();
|
||||
auto iter = part.feed_key.find(name);
|
||||
if (iter == part.feed_key.end()) {
|
||||
// The provided feed must be for a different partition.
|
||||
continue;
|
||||
}
|
||||
const string& key = iter->second;
|
||||
const TensorProto* val = feeds[name];
|
||||
if (val == nullptr) {
|
||||
return errors::InvalidArgument("No feed is provided for feed=", name,
|
||||
", key=", key);
|
||||
}
|
||||
auto* send = c->req.add_send();
|
||||
send->set_key(key);
|
||||
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
|
||||
}
|
||||
// TODO(suharshs): Make a map from feed to fetch_key to make this faster.
|
||||
// For now, we just iterate through partitions to find the matching key.
|
||||
for (const auto& req_fetch : req.fetch()) {
|
||||
for (const auto& key_fetch : part.key_fetch) {
|
||||
if (key_fetch.second == req_fetch) {
|
||||
c->req.add_recv_key(key_fetch.first);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (const auto& feed_key : part.feed_key) {
|
||||
const string& feed = feed_key.first;
|
||||
const string& key = feed_key.second;
|
||||
const TensorProto* val = feeds[feed];
|
||||
if (val == nullptr) {
|
||||
return errors::InvalidArgument("No feed is provided for feed=", feed,
|
||||
", key=", key);
|
||||
}
|
||||
auto* send = c->req.add_send();
|
||||
send->set_key(key);
|
||||
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
|
||||
}
|
||||
for (const auto& key_fetch : part.key_fetch) {
|
||||
const string& key = key_fetch.first;
|
||||
c->req.add_recv_key(key);
|
||||
}
|
||||
auto* send = c->req.add_send();
|
||||
send->set_key(key);
|
||||
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
|
||||
}
|
||||
for (const auto& key_fetch : part.key_fetch) {
|
||||
const string& key = key_fetch.first;
|
||||
c->req.add_recv_key(key);
|
||||
}
|
||||
}
|
||||
|
||||
@ -762,6 +815,64 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(suharshs): Merge with CheckFetches in DirectSession.
|
||||
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
|
||||
// on once at setup time to prevent us from computing the dependencies
|
||||
// everytime.
|
||||
Status MasterSession::ReffedClientGraph::CheckFetches(
|
||||
const RunStepRequest& req, const RunState* run_state,
|
||||
SimpleGraphExecutionState* execution_state) {
|
||||
// Build the set of pending feeds that we haven't seen.
|
||||
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
|
||||
for (const string& feed : run_state->pending_inputs) {
|
||||
TensorId id(ParseTensorName(feed));
|
||||
auto it = name_to_node_.find(id.first);
|
||||
if (it == name_to_node_.end()) {
|
||||
return errors::NotFound("Feed ", feed, ": not found");
|
||||
}
|
||||
pending_feeds.insert(id);
|
||||
}
|
||||
for (const auto& feed : req.feed()) {
|
||||
TensorId id(ParseTensorName(feed.name()));
|
||||
pending_feeds.erase(id);
|
||||
}
|
||||
|
||||
// Initialize the stack with the fetch nodes.
|
||||
std::vector<const Node*> stack;
|
||||
for (const string& fetch : req.fetch()) {
|
||||
TensorId id(ParseTensorName(fetch));
|
||||
auto it = name_to_node_.find(id.first);
|
||||
if (it == name_to_node_.end()) {
|
||||
return errors::NotFound("Fetch ", fetch, ": not found");
|
||||
}
|
||||
stack.push_back(it->second);
|
||||
}
|
||||
|
||||
// Any tensor needed for fetches can't be in pending_feeds.
|
||||
// We need to use the original full graph from execution state.
|
||||
const Graph* graph = execution_state->full_graph();
|
||||
std::vector<bool> visited(graph->num_node_ids(), false);
|
||||
while (!stack.empty()) {
|
||||
const Node* n = stack.back();
|
||||
stack.pop_back();
|
||||
|
||||
for (const Edge* in_edge : n->in_edges()) {
|
||||
const Node* in_node = in_edge->src();
|
||||
if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
|
||||
return errors::InvalidArgument("Fetch ", in_node->name(), ":",
|
||||
in_edge->src_output(),
|
||||
" can't be computed from the feeds"
|
||||
" that have been fed so far.");
|
||||
}
|
||||
if (!visited[in_node->id()]) {
|
||||
visited[in_node->id()] = true;
|
||||
stack.push_back(in_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Asynchronously deregisters subgraphs on the workers, without waiting for the
|
||||
// result.
|
||||
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
|
||||
@ -803,6 +914,23 @@ void BuildBuildGraphOptions(const RunStepRequest& req,
|
||||
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
|
||||
}
|
||||
|
||||
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
|
||||
BuildGraphOptions* opts) {
|
||||
for (const auto& feed : req.feed()) {
|
||||
opts->feed_endpoints.push_back(feed);
|
||||
}
|
||||
for (const auto& fetch : req.fetch()) {
|
||||
opts->fetch_endpoints.push_back(fetch);
|
||||
}
|
||||
for (const auto& target : req.target()) {
|
||||
opts->target_nodes.push_back(target);
|
||||
}
|
||||
|
||||
std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
|
||||
std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
|
||||
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
|
||||
}
|
||||
|
||||
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
|
||||
uint64 h = 0x2b992ddfa23249d6ull;
|
||||
for (const string& name : opts.feed_endpoints) {
|
||||
@ -927,11 +1055,9 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MasterSession::StartStep(const RunStepRequest& req,
|
||||
BuildGraphOptions* opts, int64* count,
|
||||
ReffedClientGraph** rcg) {
|
||||
BuildBuildGraphOptions(req, opts);
|
||||
const uint64 hash = HashBuildGraphOptions(*opts);
|
||||
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
|
||||
ReffedClientGraph** rcg, bool is_partial) {
|
||||
const uint64 hash = HashBuildGraphOptions(opts);
|
||||
ReffedClientGraph* to_unref = nullptr;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
@ -944,12 +1070,12 @@ Status MasterSession::StartStep(const RunStepRequest& req,
|
||||
// We have not seen this subgraph before. Build the subgraph and
|
||||
// cache it.
|
||||
VLOG(1) << "Unseen hash " << hash << " for "
|
||||
<< BuildGraphOptionsString(*opts);
|
||||
<< BuildGraphOptionsString(opts);
|
||||
std::unique_ptr<SimpleClientGraph> client_graph;
|
||||
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph));
|
||||
auto entry =
|
||||
new ReffedClientGraph(handle_, *opts, std::move(client_graph),
|
||||
session_opts_, stats_publisher_factory_);
|
||||
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
|
||||
auto entry = new ReffedClientGraph(
|
||||
handle_, opts, std::move(client_graph), session_opts_,
|
||||
stats_publisher_factory_, execution_state_.get(), is_partial);
|
||||
iter = runs_.insert({hash, entry}).first;
|
||||
auto obs_iter = obsolete_.find(hash);
|
||||
if (obs_iter != obsolete_.end()) {
|
||||
@ -979,6 +1105,47 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
||||
rcg_map->clear();
|
||||
}
|
||||
|
||||
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
|
||||
PartialRunSetupResponse* resp) {
|
||||
std::vector<string> inputs, outputs, targets;
|
||||
for (const auto& feed : req->feed()) {
|
||||
inputs.push_back(feed);
|
||||
}
|
||||
for (const auto& fetch : req->fetch()) {
|
||||
outputs.push_back(fetch);
|
||||
}
|
||||
for (const auto& target : req->target()) {
|
||||
targets.push_back(target);
|
||||
}
|
||||
|
||||
string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
|
||||
|
||||
ReffedClientGraph* rcg = nullptr;
|
||||
int64 count = 0;
|
||||
|
||||
// Prepare.
|
||||
BuildGraphOptions opts;
|
||||
BuildBuildGraphOptions(*req, &opts);
|
||||
TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true));
|
||||
// Keeps the highest 8 bits 0x01: we reserve some bits of the
|
||||
// step_id for future use.
|
||||
uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
||||
TRACEPRINTF("stepid %llu", step_id);
|
||||
|
||||
rcg->Ref();
|
||||
RunState* run_state = new RunState(inputs, outputs, rcg, step_id, count);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
partial_runs_.emplace(
|
||||
std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
|
||||
|
||||
resp->set_partial_run_handle(handle);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp) {
|
||||
UpdateLastAccessTime();
|
||||
@ -986,7 +1153,12 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
|
||||
mutex_lock l(mu_);
|
||||
++num_running_;
|
||||
}
|
||||
Status status = DoRunWithLocalExecution(opts, req, resp);
|
||||
Status status;
|
||||
if (!req->partial_run_handle().empty()) {
|
||||
status = DoPartialRun(opts, req, resp);
|
||||
} else {
|
||||
status = DoRunWithLocalExecution(opts, req, resp);
|
||||
}
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
--num_running_;
|
||||
@ -997,23 +1169,7 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
|
||||
return status;
|
||||
}
|
||||
|
||||
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
const RunStepRequest* req,
|
||||
RunStepResponse* resp) {
|
||||
VLOG(2) << "DoRunWithLocalExecution "
|
||||
<< "req: " << req->DebugString();
|
||||
PerStepState pss;
|
||||
pss.start_micros = Env::Default()->NowMicros();
|
||||
|
||||
// Prepare.
|
||||
BuildGraphOptions bgopts;
|
||||
ReffedClientGraph* rcg = nullptr;
|
||||
int64 count = 0;
|
||||
TF_RETURN_IF_ERROR(StartStep(*req, &bgopts, &count, &rcg));
|
||||
|
||||
// Unref "rcg" when out of scope.
|
||||
core::ScopedUnref unref(rcg);
|
||||
|
||||
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
||||
// Registers subgraphs if haven't done so.
|
||||
PartitionOptions popts;
|
||||
popts.node_to_loc = SplitByWorker;
|
||||
@ -1051,12 +1207,136 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(
|
||||
env_, popts, rcg->client_graph()->flib_def->ToProto()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp) {
|
||||
const string& prun_handle = req->partial_run_handle();
|
||||
RunState* run_state = nullptr;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
auto it = partial_runs_.find(prun_handle);
|
||||
if (it == partial_runs_.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Must run PartialRunSetup before performing partial runs");
|
||||
}
|
||||
run_state = it->second.get();
|
||||
}
|
||||
|
||||
// If this is the first partial run, initialize the PerStepState.
|
||||
if (!run_state->step_started) {
|
||||
run_state->step_started = true;
|
||||
PerStepState pss;
|
||||
|
||||
auto count = run_state->count;
|
||||
pss.collect_timeline =
|
||||
req->options().trace_level() == RunOptions::FULL_TRACE;
|
||||
|
||||
// Build the cost model every 'build_cost_model_every' steps after skipping
|
||||
// an
|
||||
// initial 'build_cost_model_after' steps.
|
||||
const int64 build_cost_model_after =
|
||||
session_opts_.config.graph_options().build_cost_model_after();
|
||||
const int64 build_cost_model_every =
|
||||
session_opts_.config.graph_options().build_cost_model();
|
||||
pss.collect_costs =
|
||||
build_cost_model_every > 0 &&
|
||||
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
|
||||
|
||||
std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
|
||||
run_state->step_id, count, req->options());
|
||||
if (ph) {
|
||||
pss.collect_timeline = true;
|
||||
pss.collect_rpcs = ph->should_collect_rpcs();
|
||||
}
|
||||
|
||||
run_state->pss = std::move(pss);
|
||||
run_state->ph = std::move(ph);
|
||||
}
|
||||
|
||||
// Make sure that this is a new set of feeds that are still pending.
|
||||
for (const auto& feed : req->feed()) {
|
||||
auto it = run_state->pending_inputs.find(feed.name());
|
||||
if (it == run_state->pending_inputs.end()) {
|
||||
return errors::InvalidArgument("The feed ", feed.name(),
|
||||
" had already been fed.");
|
||||
}
|
||||
}
|
||||
// Check that this is a new set of fetches that are still pending.
|
||||
for (const auto& fetch : req->fetch()) {
|
||||
auto it = run_state->pending_outputs.find(fetch);
|
||||
if (it == run_state->pending_outputs.end()) {
|
||||
return errors::InvalidArgument("The fetch ", fetch,
|
||||
" had already been fetched.");
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that the requested fetches can be computed from the provided feeds.
|
||||
TF_RETURN_IF_ERROR(
|
||||
run_state->rcg->CheckFetches(*req, run_state, execution_state_.get()));
|
||||
|
||||
// Determine if this partial run satisfies all the pending inputs and ouputs.
|
||||
for (const auto& feed : req->feed()) {
|
||||
run_state->pending_inputs.erase(feed.name());
|
||||
}
|
||||
for (const auto& fetch : req->fetch()) {
|
||||
run_state->pending_outputs.erase(fetch);
|
||||
}
|
||||
bool is_last_partial_run =
|
||||
(run_state->pending_inputs.empty() && run_state->pending_outputs.empty());
|
||||
|
||||
Status s = run_state->rcg->RunPartitions(
|
||||
env_, run_state->step_id, run_state->count, execution_state_.get(),
|
||||
&run_state->pss, opts, *req, resp, cancellation_manager_,
|
||||
is_last_partial_run);
|
||||
|
||||
// Delete the run state if there is an error or all fetches are done.
|
||||
if (!s.ok() || is_last_partial_run) {
|
||||
ReffedClientGraph* rcg = run_state->rcg;
|
||||
run_state->pss.end_micros = Env::Default()->NowMicros();
|
||||
// Schedule post-processing and cleanup to be done asynchronously.
|
||||
rcg->Ref();
|
||||
rcg->ProcessStats(env_, run_state->step_id, &run_state->pss,
|
||||
execution_state_.get(), run_state->ph.get(), *req, resp);
|
||||
rcg->CleanupPartitionsAsync(
|
||||
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Cleanup partition error: " << s;
|
||||
}
|
||||
rcg->Unref();
|
||||
mutex_lock l(mu_);
|
||||
partial_runs_.erase(prun_handle);
|
||||
});
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
const RunStepRequest* req,
|
||||
RunStepResponse* resp) {
|
||||
VLOG(2) << "DoRunWithLocalExecution "
|
||||
<< "req: " << req->DebugString();
|
||||
PerStepState pss;
|
||||
pss.start_micros = Env::Default()->NowMicros();
|
||||
|
||||
// Prepare.
|
||||
BuildGraphOptions bgopts;
|
||||
BuildBuildGraphOptions(*req, &bgopts);
|
||||
ReffedClientGraph* rcg = nullptr;
|
||||
int64 count = 0;
|
||||
TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false));
|
||||
|
||||
// Unref "rcg" when out of scope.
|
||||
core::ScopedUnref unref(rcg);
|
||||
|
||||
TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
|
||||
|
||||
// Keeps the highest 8 bits 0x01: we reserve some bits of the
|
||||
// step_id for future use.
|
||||
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
||||
TRACEPRINTF("stepid %llu", step_id);
|
||||
|
||||
std::unique_ptr<ProfileHandler> ph;
|
||||
pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
|
||||
|
||||
// Build the cost model every 'build_cost_model_every' steps after skipping an
|
||||
@ -1069,15 +1349,16 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
build_cost_model_every > 0 &&
|
||||
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
|
||||
|
||||
ph = rcg->GetProfileHandler(step_id, count, req->options());
|
||||
std::unique_ptr<ProfileHandler> ph =
|
||||
rcg->GetProfileHandler(step_id, count, req->options());
|
||||
if (ph) {
|
||||
pss.collect_timeline = true;
|
||||
pss.collect_rpcs = ph->should_collect_rpcs();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
|
||||
execution_state_.get(), &pss, opts,
|
||||
*req, resp, cancellation_manager_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
rcg->RunPartitions(env_, step_id, count, execution_state_.get(), &pss,
|
||||
opts, *req, resp, cancellation_manager_, false));
|
||||
|
||||
pss.end_micros = Env::Default()->NowMicros();
|
||||
|
||||
@ -1110,4 +1391,22 @@ Status MasterSession::Close() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
MasterSession::RunState::RunState(const std::vector<string>& input_names,
|
||||
const std::vector<string>& output_names,
|
||||
ReffedClientGraph* rcg, const uint64 step_id,
|
||||
const int64 count)
|
||||
: rcg(rcg), step_id(step_id), count(count) {
|
||||
// Initially all the feeds and fetches are pending.
|
||||
for (auto& name : input_names) {
|
||||
pending_inputs.emplace(name);
|
||||
}
|
||||
for (auto& name : output_names) {
|
||||
pending_outputs.emplace(name);
|
||||
}
|
||||
}
|
||||
|
||||
MasterSession::RunState::~RunState() {
|
||||
if (rcg) rcg->Unref();
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
@ -72,6 +73,10 @@ class MasterSession {
|
||||
// Extend() may block the caller thread for a long time.
|
||||
Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);
|
||||
|
||||
// Setup a partial run call.
|
||||
Status PartialRunSetup(const PartialRunSetupRequest* req,
|
||||
PartialRunSetupResponse* resp);
|
||||
|
||||
// Run one step.
|
||||
Status Run(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp);
|
||||
@ -101,6 +106,8 @@ class MasterSession {
|
||||
|
||||
std::atomic_ulong last_access_time_usec_;
|
||||
|
||||
std::atomic<int64> partial_run_handle_counter_ = {0};
|
||||
|
||||
mutex mu_;
|
||||
std::unique_ptr<SimpleGraphExecutionState> execution_state_;
|
||||
int64 graph_version_;
|
||||
@ -115,6 +122,36 @@ class MasterSession {
|
||||
RCGMap runs_ GUARDED_BY(mu_);
|
||||
RCGMap obsolete_ GUARDED_BY(mu_);
|
||||
|
||||
struct PerStepState {
|
||||
bool collect_costs = false;
|
||||
bool collect_timeline = false;
|
||||
bool collect_rpcs = false;
|
||||
Microseconds start_micros = Microseconds(0);
|
||||
Microseconds end_micros = Microseconds(0);
|
||||
std::vector<StepStats> step_stats; // per partition
|
||||
StepStats rpc_stats; // for RPC layer
|
||||
CostGraphDef cost_graph;
|
||||
};
|
||||
|
||||
struct RunState {
|
||||
std::unordered_set<string> pending_inputs;
|
||||
std::unordered_set<string> pending_outputs;
|
||||
ReffedClientGraph* rcg = nullptr;
|
||||
uint64 step_id;
|
||||
int64 count = 0;
|
||||
PerStepState pss;
|
||||
std::unique_ptr<ProfileHandler> ph;
|
||||
bool step_started = false;
|
||||
|
||||
RunState(const std::vector<string>& input_names,
|
||||
const std::vector<string>& output_names, ReffedClientGraph* rcg,
|
||||
const uint64 step_id, const int64 count);
|
||||
|
||||
~RunState();
|
||||
};
|
||||
std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
|
||||
GUARDED_BY(mu_);
|
||||
|
||||
// Active RunStep calls.
|
||||
condition_variable num_running_is_zero_;
|
||||
int32 num_running_ GUARDED_BY(mu_) = 0;
|
||||
@ -131,14 +168,18 @@ class MasterSession {
|
||||
// Private dtor. The client must call Close().
|
||||
virtual ~MasterSession();
|
||||
|
||||
Status StartStep(const RunStepRequest& req, BuildGraphOptions* opts,
|
||||
int64* count, ReffedClientGraph** graph);
|
||||
Status StartStep(const BuildGraphOptions& opts, int64* count,
|
||||
ReffedClientGraph** graph, bool is_partial);
|
||||
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
||||
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp);
|
||||
Status DoPartialRun(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp);
|
||||
void UpdateLastAccessTime();
|
||||
|
||||
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
|
||||
};
|
||||
|
||||
|
@ -131,7 +131,7 @@ class OpKernel {
|
||||
// We allow legacy scalars within Google up until GraphDef version 6.
|
||||
// TODO(irving): Remove when we can drop support for GraphDef version 5.
|
||||
bool allow_legacy_scalars() const {
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
|
||||
return graph_def_version_ < 6;
|
||||
#else
|
||||
return false;
|
||||
|
@ -1136,8 +1136,9 @@ tf_kernel_libraries(
|
||||
":eigen_helpers",
|
||||
":image_resizer_state",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gif_internal",
|
||||
"//tensorflow/core:image_ops_op_lib",
|
||||
"//tensorflow/core:jpeg",
|
||||
"//tensorflow/core:jpeg_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -2099,11 +2100,13 @@ tf_kernel_libraries(
|
||||
"count_up_to_op",
|
||||
"dense_update_ops",
|
||||
"scatter_op",
|
||||
"scatter_nd_op",
|
||||
"variable_ops",
|
||||
],
|
||||
deps = [
|
||||
":assign_op",
|
||||
":bounds_check",
|
||||
":fill_functor",
|
||||
":scatter_functor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -2117,6 +2120,7 @@ tf_cc_test(
|
||||
size = "small",
|
||||
srcs = ["scatter_op_test.cc"],
|
||||
deps = [
|
||||
":fill_functor",
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
":scatter_op",
|
||||
@ -2129,6 +2133,23 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "scatter_nd_op_test",
|
||||
size = "small",
|
||||
srcs = ["scatter_nd_op_test.cc"],
|
||||
deps = [
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
":scatter_nd_op",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_libraries(
|
||||
name = "string",
|
||||
prefixes = [
|
||||
@ -2571,6 +2592,7 @@ filegroup(
|
||||
"debug_ops.*",
|
||||
# Ops excluded because they do not build correctly for Android.
|
||||
# See b/29213790
|
||||
"scatter_nd_op*",
|
||||
"sparse_matmul_op.*",
|
||||
],
|
||||
),
|
||||
|
@ -64,7 +64,7 @@ class TestFileSystem : public NullFileSystem {
|
||||
std::unique_ptr<ReadOnlyMemoryRegion>* result) override {
|
||||
float val = 0;
|
||||
StringPiece scheme, host, path;
|
||||
ParseURI(fname, &scheme, &host, &path);
|
||||
io::ParseURI(fname, &scheme, &host, &path);
|
||||
// For the tests create in-memory regions with float values equal to the
|
||||
// region name.
|
||||
if (path == "/2") {
|
||||
|
@ -46,25 +46,6 @@ namespace functor {
|
||||
using random::PhiloxRandom;
|
||||
using random::SingleSampleAdapter;
|
||||
|
||||
// Sample a truncated normal random variable, with mean, stddev, minval, and
|
||||
// maxval parameters for each batch. Uses two rejection sampling algorithms
|
||||
// described in http://rd.springer.com/article/10.1007/BF00143942.
|
||||
//
|
||||
// Either minval may be -infinity, or maxval may be +infinity. If the interval
|
||||
// (minval, maxval) is empty, the result is NaN. Large intervals which include
|
||||
// both tails may have reduced accuracy.
|
||||
template <typename Device, typename T>
|
||||
struct TruncatedNormalFunctor {
|
||||
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
|
||||
int64 samples_per_batch, int64 num_elements,
|
||||
typename TTypes<T>::ConstFlat means,
|
||||
typename TTypes<T>::ConstFlat stddevs,
|
||||
typename TTypes<T>::ConstFlat minvals,
|
||||
typename TTypes<T>::ConstFlat maxvals,
|
||||
const random::PhiloxRandom& gen,
|
||||
typename TTypes<T>::Flat output);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
static const int kMaxIterations = 100;
|
||||
@ -96,8 +77,8 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
|
||||
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||
// We always generate at most 4 samples.
|
||||
tensorflow::random::Array<T, 4> z;
|
||||
tensorflow::random::Array<T, 4> g;
|
||||
Eigen::array<T, 4> z;
|
||||
Eigen::array<T, 4> g;
|
||||
|
||||
for (int64 b = start_batch; b < limit_batch; ++b) {
|
||||
// We are passed a flat array for each of the parameter tensors.
|
||||
@ -145,13 +126,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
if (diff < cutoff) {
|
||||
// Sample from a uniform distribution on [normMin, normMax].
|
||||
|
||||
T plusFactor;
|
||||
if (normMin < T(0)) {
|
||||
// normMax > 0 because it is flipped otherwise.
|
||||
plusFactor = T(0);
|
||||
} else {
|
||||
plusFactor = normMin * normMin;
|
||||
}
|
||||
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
|
||||
|
||||
while (sample < limit_sample) {
|
||||
const auto rand = dist(&gen_copy);
|
||||
@ -395,4 +370,21 @@ TF_CALL_double(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("shape") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)
|
||||
|
||||
TF_CALL_half(REGISTER);
|
||||
TF_CALL_float(REGISTER);
|
||||
TF_CALL_double(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -16,14 +16,35 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
||||
#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpKernelContext;
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Sample a truncated normal random variable, with mean, stddev, minval, and
|
||||
// maxval parameters for each batch. Uses two rejection sampling algorithms
|
||||
// described in http://rd.springer.com/article/10.1007/BF00143942.
|
||||
//
|
||||
// Either minval may be -infinity, or maxval may be +infinity. If the interval
|
||||
// (minval, maxval) is empty, the result is NaN. Large intervals which include
|
||||
// both tails may have reduced accuracy.
|
||||
template <typename Device, typename T>
|
||||
struct TruncatedNormalFunctor;
|
||||
struct TruncatedNormalFunctor {
|
||||
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
|
||||
int64 samples_per_batch, int64 num_elements,
|
||||
typename TTypes<T>::ConstFlat means,
|
||||
typename TTypes<T>::ConstFlat stddevs,
|
||||
typename TTypes<T>::ConstFlat minvals,
|
||||
typename TTypes<T>::ConstFlat maxvals,
|
||||
const random::PhiloxRandom& gen,
|
||||
typename TTypes<T>::Flat output);
|
||||
|
||||
static const int kMaxIterations = 100;
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,214 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
#include <cmath>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
#define UNROLL _Pragma("unroll")
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpKernelContext;
|
||||
|
||||
namespace functor {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename T>
|
||||
__global__ void __launch_bounds__(1024)
|
||||
TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
|
||||
int64 samples_per_batch, int64 num_elements,
|
||||
const T* means, bool single_mean, const T* stddevs,
|
||||
bool single_stddev, const T* minvals,
|
||||
bool single_minval, const T* maxvals,
|
||||
bool single_maxval, int64 kMaxIterations) {
|
||||
const int32 max_samples_per_item = 2 * kMaxIterations;
|
||||
// Initial offset as given by CUDA_1D_KERNEL_LOOP.
|
||||
const int32 initial_offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
gen.Skip(max_samples_per_item * initial_offset);
|
||||
typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
|
||||
Uniform dist;
|
||||
const int kDistSize = Uniform::kResultElementCount;
|
||||
const T quietNaN = Eigen::NumTraits<T>::quiet_NaN();
|
||||
|
||||
// We skip the total number of threads to get to the next element. To produce
|
||||
// deterministic results between devices, each element in the output array
|
||||
// skips max_samples_per_item in the generator. Then after generating this
|
||||
// item, we need to skip the samples for one element for every thread to get
|
||||
// to the next element that we actually process.
|
||||
const int32 samples_between_processed_elements =
|
||||
max_samples_per_item * (gridDim.x * blockDim.x);
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(offset, num_elements) {
|
||||
// Track how many more samples we need to skip before we process the next
|
||||
// element.
|
||||
int32 remaining_samples = samples_between_processed_elements;
|
||||
|
||||
const int64 batch_id = offset / samples_per_batch;
|
||||
T mean = means[single_mean ? 0 : batch_id];
|
||||
const T input_stddev = stddevs[single_stddev ? 0 : batch_id];
|
||||
T minval = minvals[single_minval ? 0 : batch_id];
|
||||
T maxval = maxvals[single_maxval ? 0 : batch_id];
|
||||
|
||||
// Flip the distribution if we can make the lower bound positive.
|
||||
T stddev;
|
||||
if (Eigen::numext::isinf(minval) || maxval < mean) {
|
||||
// Reverse all calculations. normMin and normMax will be flipped.
|
||||
// std::swap is a host function (not available in CUDA).
|
||||
T temp = minval;
|
||||
minval = maxval;
|
||||
maxval = temp;
|
||||
stddev = -input_stddev;
|
||||
} else {
|
||||
stddev = input_stddev;
|
||||
}
|
||||
|
||||
// Calculate normalized samples, then scale them.
|
||||
const T normMin = (minval - mean) / stddev;
|
||||
const T normMax = (maxval - mean) / stddev;
|
||||
|
||||
// Determine the method to use.
|
||||
const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
|
||||
const T cutoff =
|
||||
T(2) *
|
||||
Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) /
|
||||
(normMin + sqrtFactor);
|
||||
const T diff = normMax - normMin;
|
||||
|
||||
// Validate the normalized min and max, because the originals may have been
|
||||
// flipped already.
|
||||
if (!(input_stddev > T(0) && normMin < normMax &&
|
||||
(Eigen::numext::isfinite(normMin) ||
|
||||
Eigen::numext::isfinite(normMax)))) {
|
||||
data[offset] = quietNaN;
|
||||
} else if (diff < cutoff) {
|
||||
// Sample from a uniform distribution on [normMin, normMax].
|
||||
|
||||
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||
// We always generate at most 4 samples.
|
||||
Eigen::array<T, 4> z;
|
||||
Eigen::array<T, 4> g;
|
||||
|
||||
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
|
||||
|
||||
int numIterations = 0;
|
||||
while (numIterations < kMaxIterations) {
|
||||
const auto rand = dist(&gen);
|
||||
remaining_samples -= gen.kResultElementCount;
|
||||
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||
z[i] = rand[i] * diff + normMin;
|
||||
}
|
||||
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||
g[i] = (plusFactor - z[i] * z[i]) / 2.0;
|
||||
}
|
||||
|
||||
const auto u = dist(&gen);
|
||||
remaining_samples -= gen.kResultElementCount;
|
||||
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||
if (u[i] <= Eigen::numext::exp(g[i]) ||
|
||||
numIterations + 1 >= kMaxIterations) {
|
||||
// Accept the sample z.
|
||||
// If we run out of iterations, just use the current uniform
|
||||
// sample. Emperically, the probability of accepting each sample
|
||||
// is at least 50% for typical inputs, so we will always accept
|
||||
// by 100 iterations.
|
||||
// This introduces a slight inaccuracy when at least one bound
|
||||
// is large, minval is negative and maxval is positive.
|
||||
data[offset] = z[i] * stddev + mean;
|
||||
// Break out of the nested loop by updating numIterations.
|
||||
numIterations = kMaxIterations;
|
||||
break;
|
||||
} else {
|
||||
numIterations++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sample from an exponential distribution with alpha maximizing
|
||||
// acceptance probability, offset by normMin from the origin.
|
||||
// Accept only if less than normMax.
|
||||
const T alpha =
|
||||
(normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) / T(2);
|
||||
int numIterations = 0;
|
||||
while (numIterations < kMaxIterations) {
|
||||
auto rand = dist(&gen);
|
||||
remaining_samples -= gen.kResultElementCount;
|
||||
UNROLL for (int i = 0; i < kDistSize; i += 2) {
|
||||
const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
|
||||
const T x = normMin < alpha ? alpha - z : normMin - alpha;
|
||||
const T g = Eigen::numext::exp(-x * x / 2.0);
|
||||
const T u = rand[i + 1];
|
||||
if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) {
|
||||
data[offset] = z * stddev + mean;
|
||||
// Break out of the nested loop by updating numIterations.
|
||||
numIterations = kMaxIterations;
|
||||
break;
|
||||
} else {
|
||||
numIterations++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
gen.Skip(remaining_samples);
|
||||
}
|
||||
}
|
||||
|
||||
// Partial specialization for GPU
|
||||
template <typename T>
|
||||
struct TruncatedNormalFunctor<GPUDevice, T> {
|
||||
static const int kMaxIterations = 100;
|
||||
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
|
||||
int64 samples_per_batch, int64 num_elements,
|
||||
typename TTypes<T>::ConstFlat means,
|
||||
typename TTypes<T>::ConstFlat stddevs,
|
||||
typename TTypes<T>::ConstFlat minvals,
|
||||
typename TTypes<T>::ConstFlat maxvals,
|
||||
const random::PhiloxRandom& gen,
|
||||
typename TTypes<T>::Flat output) {
|
||||
const auto config = GetCudaLaunchConfig(num_elements, d);
|
||||
|
||||
TruncatedNormalKernel<
|
||||
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
gen, output.data(), num_batches, samples_per_batch, num_elements,
|
||||
means.data(), means.dimension(0) == 1, stddevs.data(),
|
||||
stddevs.dimension(0) == 1, minvals.data(), minvals.dimension(0) == 1,
|
||||
maxvals.data(), maxvals.dimension(0) == 1, kMaxIterations);
|
||||
};
|
||||
};
|
||||
|
||||
// Explicit instantiation of the GPU distributions functors
|
||||
template struct TruncatedNormalFunctor<GPUDevice, Eigen::half>;
|
||||
template struct TruncatedNormalFunctor<GPUDevice, float>;
|
||||
template struct TruncatedNormalFunctor<GPUDevice, double>;
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
@ -131,5 +131,8 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
|
||||
BM_PTruncatedNormalDev(cpu, 1000, 1000);
|
||||
BM_PTruncatedNormalDev_2SD(cpu, 10000, 100);
|
||||
BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100);
|
||||
BM_PTruncatedNormalDev(gpu, 1000, 1000);
|
||||
BM_PTruncatedNormalDev_2SD(gpu, 10000, 100);
|
||||
BM_PTruncatedNormalDev_OneTail(gpu, 10000, 100);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
402
tensorflow/core/kernels/scatter_nd_op.cc
Normal file
402
tensorflow/core/kernels/scatter_nd_op.cc
Normal file
@ -0,0 +1,402 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/state_ops.cc.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/scatter_nd_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
// Check whether updates.shape = indices.shape[0] + params.shape[IXDIM:]
|
||||
static bool ValidUpdateShape(const TensorShape& params_shape,
|
||||
const Tensor& indices, const Tensor& updates) {
|
||||
int64 indices_nd = 1;
|
||||
if (indices.dims() > 1) {
|
||||
indices_nd = indices.dim_size(1);
|
||||
}
|
||||
for (int d = indices_nd; d < params_shape.dims(); d++) {
|
||||
if (updates.dim_size(d - indices_nd + 1) != params_shape.dim_size(d)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Index>
|
||||
static void PrepareAndValidateInputs(OpKernelContext* c,
|
||||
const TensorShape& params_shape,
|
||||
const Tensor& indices,
|
||||
const Tensor& updates, int64* indices_nd,
|
||||
Index* num_updates, Index* slice_size) {
|
||||
const TensorShape& indices_shape(indices.shape());
|
||||
const TensorShape& updates_shape(updates.shape());
|
||||
|
||||
OP_REQUIRES(
|
||||
c, TensorShapeUtils::IsVectorOrHigher(params_shape),
|
||||
errors::InvalidArgument("Output must be at least 1-D, ", "got shape: ",
|
||||
params_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES(c, params_shape.num_elements() >= 0 ||
|
||||
(indices.NumElements() == 0 && updates.NumElements() == 0),
|
||||
errors::InvalidArgument(
|
||||
"Indices and updates specified for empty output", " shape"));
|
||||
|
||||
OP_REQUIRES(c, updates.dim_size(0) == indices.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"The outermost dimension of updates and indices ",
|
||||
"must match. Got indices.shape ", indices_shape.DebugString(),
|
||||
", updates.shape ", updates_shape.DebugString()));
|
||||
OP_REQUIRES(
|
||||
c, ValidUpdateShape(params_shape, indices, updates),
|
||||
errors::InvalidArgument(
|
||||
"Must have updates.shape = indices.shape[0] + params_shape[IXDIM:], ",
|
||||
"got updates.shape ", updates_shape.DebugString(), ", indices.shape ",
|
||||
indices_shape.DebugString(), ", params_shape ",
|
||||
params_shape.DebugString()));
|
||||
// Check that we have enough index space
|
||||
const int64 N_big = indices.NumElements();
|
||||
OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
|
||||
errors::InvalidArgument(
|
||||
"indices has too many elements for ",
|
||||
DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
|
||||
N_big, " > ", std::numeric_limits<Index>::max()));
|
||||
OP_REQUIRES(
|
||||
c, params_shape.dim_size(0) <= std::numeric_limits<Index>::max(),
|
||||
errors::InvalidArgument("params_shape[0] too large for ",
|
||||
DataTypeString(DataTypeToEnum<Index>::v()),
|
||||
" indexing: ", params_shape.dim_size(0), " > ",
|
||||
std::numeric_limits<Index>::max()));
|
||||
|
||||
// Calculate the number of dimensions in indices
|
||||
*indices_nd = 1;
|
||||
if (indices_shape.dims() > 1) {
|
||||
*indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);
|
||||
}
|
||||
|
||||
// Calculate the number of elements that make up each slice of our updated
|
||||
// tensor. This allows us to work with flattened tensors and copy over whole
|
||||
// slices at a time.
|
||||
Index total_nd = params_shape.dims();
|
||||
|
||||
int64 slice_size_big = 1;
|
||||
for (int64 i = *indices_nd; i < total_nd; ++i) {
|
||||
slice_size_big *= params_shape.dim_size(i);
|
||||
}
|
||||
|
||||
OP_REQUIRES(c, slice_size_big <= std::numeric_limits<Index>::max(),
|
||||
errors::InvalidArgument("slice size is too large for indexing: ",
|
||||
slice_size_big, " > ",
|
||||
std::numeric_limits<Index>::max()));
|
||||
|
||||
*slice_size = static_cast<Index>(slice_size_big);
|
||||
|
||||
const int64 safe_indices_nd = (*indices_nd < 1) ? 1 : *indices_nd;
|
||||
*num_updates = indices_shape.num_elements() / safe_indices_nd;
|
||||
}
|
||||
|
||||
template <typename Device, typename T, typename Index>
|
||||
class ScatterNdOp : public OpKernel {
|
||||
public:
|
||||
explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||
const DataType dt = DataTypeToEnum<T>::v();
|
||||
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||
OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* c) override {
|
||||
const Tensor& indices = c->input(0);
|
||||
const Tensor& updates = c->input(1);
|
||||
const Tensor& shape_input = c->input(2);
|
||||
|
||||
OP_REQUIRES(c, shape_input.dims() == 1,
|
||||
errors::InvalidArgument("Shape must be a vector"));
|
||||
auto vec = shape_input.flat<Index>();
|
||||
TensorShape shape;
|
||||
TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape);
|
||||
|
||||
int64 indices_nd;
|
||||
Index num_updates;
|
||||
Index slice_size;
|
||||
PrepareAndValidateInputs<Index>(c, shape, indices, updates, &indices_nd,
|
||||
&num_updates, &slice_size);
|
||||
if (!c->status().ok()) return;
|
||||
|
||||
Tensor scratch;
|
||||
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
|
||||
|
||||
auto scratch_scalar = scratch.scalar<Index>();
|
||||
auto indices_flat = indices.flat_inner_dims<Index>();
|
||||
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
|
||||
|
||||
Index bad_i = -1;
|
||||
switch (indices_nd) {
|
||||
#define PARAMS_CASE(IXDIM) \
|
||||
case IXDIM: { \
|
||||
Tensor* out = nullptr; \
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); \
|
||||
functor::SetZeroFunctor<Device, T> fill; \
|
||||
fill(c->eigen_device<Device>(), out->flat<T>()); \
|
||||
if (shape.num_elements() > 0) { \
|
||||
auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>(); \
|
||||
functor::ScatterNdFunctor<Device, T, Index, \
|
||||
scatter_nd_op::UpdateOp::ASSIGN, (IXDIM)> \
|
||||
functor; \
|
||||
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
|
||||
output_flat, indices_flat, updates_flat, output_flat); \
|
||||
} \
|
||||
} break
|
||||
PARAMS_CASE(0);
|
||||
PARAMS_CASE(1);
|
||||
PARAMS_CASE(2);
|
||||
PARAMS_CASE(3);
|
||||
PARAMS_CASE(4);
|
||||
PARAMS_CASE(5);
|
||||
#undef PARAMS_CASE
|
||||
default:
|
||||
OP_REQUIRES(c, false,
|
||||
errors::InvalidArgument(
|
||||
"Only indices.shape[-1] values between 0 and 5 "
|
||||
"are currently supported. Requested rank: ",
|
||||
indices_nd));
|
||||
}
|
||||
OP_REQUIRES(
|
||||
c, bad_i < 0,
|
||||
errors::InvalidArgument(
|
||||
"Invalid indices: ", SliceDebugString(indices.shape(), bad_i),
|
||||
" = [", str_util::Join(gtl::ArraySlice<Index>(
|
||||
&indices_flat(bad_i, 0), indices_nd),
|
||||
", "),
|
||||
"] does not index into ", shape.DebugString()));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T, typename Index,
|
||||
scatter_nd_op::UpdateOp op>
|
||||
class ScatterNdUpdateOp : public OpKernel {
|
||||
public:
|
||||
explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||
const DataType dt = DataTypeToEnum<T>::v();
|
||||
const DataType dt_ref = DataTypeToEnum<T>::ref();
|
||||
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||
OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* c) override {
|
||||
if (use_exclusive_lock_) {
|
||||
// Hold mutex while we apply updates
|
||||
mutex_lock l(*c->input_ref_mutex(0));
|
||||
DoCompute(c);
|
||||
} else {
|
||||
DoCompute(c);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
|
||||
void DoCompute(OpKernelContext* c) {
|
||||
Tensor params = c->mutable_input(0, use_exclusive_lock_);
|
||||
const Tensor& indices = c->input(1);
|
||||
const Tensor& updates = c->input(2);
|
||||
const TensorShape& params_shape(params.shape());
|
||||
|
||||
int64 indices_nd;
|
||||
Index num_updates;
|
||||
Index slice_size;
|
||||
|
||||
OP_REQUIRES(c, params.IsInitialized(),
|
||||
errors::FailedPrecondition("Null ref for params"));
|
||||
PrepareAndValidateInputs<Index>(c, params_shape, indices, updates,
|
||||
&indices_nd, &num_updates, &slice_size);
|
||||
if (!c->status().ok()) return;
|
||||
|
||||
Tensor scratch;
|
||||
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
|
||||
|
||||
auto scratch_scalar = scratch.scalar<Index>();
|
||||
auto indices_flat = indices.flat_inner_dims<Index>();
|
||||
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
|
||||
|
||||
Index bad_i = -1;
|
||||
c->forward_ref_input_to_ref_output(0, 0);
|
||||
switch (indices_nd) {
|
||||
#define PARAMS_CASE(IXDIM) \
|
||||
case IXDIM: { \
|
||||
auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>(); \
|
||||
functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \
|
||||
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
|
||||
params_flat, indices_flat, updates_flat, params_flat); \
|
||||
} break
|
||||
PARAMS_CASE(0);
|
||||
PARAMS_CASE(1);
|
||||
PARAMS_CASE(2);
|
||||
PARAMS_CASE(3);
|
||||
PARAMS_CASE(4);
|
||||
PARAMS_CASE(5);
|
||||
#undef PARAMS_CASE
|
||||
default:
|
||||
OP_REQUIRES(c, false,
|
||||
errors::InvalidArgument(
|
||||
"Only indices.shape[-1] values between 1 and 5 "
|
||||
"are currently supported. Requested rank: ",
|
||||
indices_nd));
|
||||
}
|
||||
OP_REQUIRES(
|
||||
c, bad_i < 0,
|
||||
errors::InvalidArgument(
|
||||
"Invalid indices: ", SliceDebugString(indices.shape(), bad_i),
|
||||
" = [", str_util::Join(gtl::ArraySlice<Index>(
|
||||
&indices_flat(bad_i, 0), indices_nd),
|
||||
", "),
|
||||
"] is not in [0, ", params.dim_size(0), ")"));
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
|
||||
REGISTER_KERNEL_BUILDER(Name(name) \
|
||||
.Device(DEVICE_##dev) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<index_type>("Tindices"), \
|
||||
ScatterNdOp<dev##Device, type, index_type>)
|
||||
|
||||
#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
|
||||
op) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name(name) \
|
||||
.Device(DEVICE_##dev) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<index_type>("Tindices"), \
|
||||
ScatterNdUpdateOp<dev##Device, type, index_type, op>)
|
||||
|
||||
#define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \
|
||||
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
|
||||
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)
|
||||
|
||||
#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
|
||||
|
||||
#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
|
||||
scatter_nd_op::UpdateOp::ADD); \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
|
||||
scatter_nd_op::UpdateOp::SUB); \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
|
||||
scatter_nd_op::UpdateOp::MUL); \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
|
||||
scatter_nd_op::UpdateOp::DIV);
|
||||
|
||||
#define REGISTER_SCATTER_ND(type, dev) \
|
||||
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
|
||||
|
||||
#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
|
||||
scatter_nd_op::UpdateOp::ASSIGN);
|
||||
|
||||
// Registers CPU kernels.
|
||||
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
|
||||
REGISTER_SCATTER_ND_ADD_SUB(type, CPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
|
||||
REGISTER_SCATTER_ND_UPDATE(type, CPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
|
||||
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
|
||||
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
|
||||
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
|
||||
|
||||
// Registers GPU kernels.
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
|
||||
REGISTER_SCATTER_ND_ADD_SUB(type, GPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
|
||||
REGISTER_SCATTER_ND_UPDATE(type, GPU);
|
||||
|
||||
// TODO(simister): Re-enable when GPU support is working.
|
||||
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_GPU);
|
||||
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_GPU);
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#undef REGISTER_SCATTER_ND_ADD
|
||||
#undef REGISTER_SCATTER_ND_ADD_SUB
|
||||
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
|
||||
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
|
||||
#undef REGISTER_SCATTER_ND_UPDATE
|
||||
#undef REGISTER_SCATTER_ND_UPDATE_CPU
|
||||
#undef REGISTER_SCATTER_ND_UPDATE_GPU
|
||||
#undef REGISTER_SCATTER_ND_KERNEL
|
||||
#undef REGISTER_SCATTER_ND_KERNEL_INDEX
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
|
||||
#define DECLARE_GPU_SPECS_OP(T, Index, op, NDIM) \
|
||||
template <> \
|
||||
Index ScatterNdFunctor<GPUDevice, T, Index, op, NDIM>::operator()( \
|
||||
OpKernelContext* c, const GPUDevice& d, \
|
||||
typename TTypes<T, IXDIM>::Tensor params, \
|
||||
typename TTypes<Index, 2>::ConstTensor indices, \
|
||||
typename TTypes<T, 2>::ConstTensor updates); \
|
||||
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op>;
|
||||
|
||||
#define DECLARE_GPU_SPECS_OPS(T, Index, op) \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 0); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 1); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 2); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 3); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 4); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 5)
|
||||
|
||||
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD); \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB); \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL); \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV);
|
||||
|
||||
#define DECLARE_GPU_SPECS(T) \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int32); \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int64);
|
||||
|
||||
// TODO(simister): Re-enable when GPU support is working.
|
||||
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
|
||||
|
||||
#undef DECLARE_GPU_SPECS
|
||||
#undef DECLARE_GPU_SPECS_INDEX
|
||||
#undef DECLARE_GPU_SPECS_OPS
|
||||
#undef DECLARE_GPU_SPECS_OP
|
||||
|
||||
} // namespace functor
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
62
tensorflow/core/kernels/scatter_nd_op.h
Normal file
62
tensorflow/core/kernels/scatter_nd_op.h
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
|
||||
#define TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/scatter_nd_op.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
class OpKernelContext;
|
||||
|
||||
namespace scatter_nd_op {
|
||||
|
||||
enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
|
||||
|
||||
} // namespace scatter_nd_op
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Functor used by ScatterOp to do the computations.
|
||||
template <typename Device, typename T, typename Index,
|
||||
scatter_nd_op::UpdateOp op, int IXDIM>
|
||||
struct ScatterNdFunctor {
|
||||
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
|
||||
Index operator()(const Device& d, const Index slice_size,
|
||||
typename TTypes<Index>::Scalar Tscratch,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Toutput);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
|
224
tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
Normal file
224
tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
Normal file
@ -0,0 +1,224 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
|
||||
|
||||
// Functor definitions for ScatterND ops, must be compilable by nvcc.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/scatter_nd_op.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
class OpKernelContext;
|
||||
|
||||
// Specialization of UpdateExecutor to CPU
|
||||
namespace generator {
|
||||
|
||||
template <typename T, typename Index, scatter_nd_op::UpdateOp op>
|
||||
class UpdateExecutor {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size);
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ASSIGN> {
|
||||
public:
|
||||
static void Update(T* /* unused */, const T* updates, T* output,
|
||||
Index slice_size) {
|
||||
std::copy_n(updates, slice_size, output);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ADD> {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||
std::transform(input, input + slice_size, updates, output, std::plus<T>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::SUB> {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||
std::transform(input, input + slice_size, updates, output, std::minus<T>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::MUL> {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||
std::transform(input, input + slice_size, updates, output,
|
||||
std::multiplies<T>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::DIV> {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||
std::transform(input, input + slice_size, updates, output,
|
||||
std::divides<T>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
|
||||
class ScatterNdSliceGenerator {
|
||||
public:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ScatterNdSliceGenerator(
|
||||
const Index slice_size, typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Toutput,
|
||||
std::atomic<Index>* error_loc)
|
||||
: slice_size_(slice_size),
|
||||
Tparams_(Tparams),
|
||||
Tindices_(Tindices),
|
||||
Tupdates_(Tupdates),
|
||||
Toutput_(Toutput),
|
||||
error_loc_(error_loc) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC bool GenerateIndices(
|
||||
const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
|
||||
(*ix)[IXDIM] = 0;
|
||||
bool out_of_bounds = false;
|
||||
for (int i = 0; i < IXDIM; ++i) {
|
||||
const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
|
||||
(*ix)[i] = ix_i;
|
||||
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
|
||||
}
|
||||
return out_of_bounds;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
|
||||
operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
|
||||
auto loc = loc_array[0];
|
||||
Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix_params;
|
||||
Eigen::array<Eigen::DenseIndex, 2> ix_updates;
|
||||
ix_updates[0] = loc;
|
||||
ix_updates[1] = 0;
|
||||
const bool out_of_bounds = GenerateIndices(loc, &ix_params);
|
||||
if (TF_PREDICT_FALSE(out_of_bounds)) {
|
||||
error_loc_->store(loc);
|
||||
} else {
|
||||
UpdateExecutor<T, Index, op>::Update(&Tparams_(ix_params),
|
||||
&Tupdates_(ix_updates),
|
||||
&Toutput_(ix_params), slice_size_);
|
||||
}
|
||||
return static_cast<int32>(0); // Return something...
|
||||
}
|
||||
|
||||
protected:
|
||||
const Index slice_size_;
|
||||
mutable typename TTypes<T, IXDIM + 1>::Tensor Tparams_;
|
||||
const typename TTypes<Index, 2>::ConstTensor Tindices_;
|
||||
const typename TTypes<T, 2>::ConstTensor Tupdates_;
|
||||
mutable typename TTypes<T, IXDIM + 1>::Tensor Toutput_;
|
||||
std::atomic<Index>* error_loc_;
|
||||
};
|
||||
|
||||
} // namespace generator
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Implementation of update functor for CPU.
|
||||
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
|
||||
struct ScatterNdFunctor<CPUDevice, T, Index, op, IXDIM> {
|
||||
Index operator()(const CPUDevice& d, const Index slice_size,
|
||||
typename TTypes<Index>::Scalar Tscratch,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Toutput) {
|
||||
std::atomic<Index> error_loc(-1);
|
||||
|
||||
const Eigen::DenseIndex batch_size = Tindices.dimension(0);
|
||||
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||
Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
|
||||
Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
|
||||
#else
|
||||
Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
|
||||
Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
|
||||
broadcast_dims.set(0, batch_size);
|
||||
#endif
|
||||
|
||||
generator::ScatterNdSliceGenerator<T, Index, op, IXDIM> generator(
|
||||
slice_size, Tparams, Tindices, Tupdates, Toutput, &error_loc);
|
||||
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
|
||||
.broadcast(broadcast_dims)
|
||||
.generate(generator)
|
||||
.sum();
|
||||
|
||||
// error_loc() returns -1 if there's no out-of-bounds index,
|
||||
// otherwise it returns the location of an OOB index in Tindices.
|
||||
return error_loc.load();
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
|
||||
template Index \
|
||||
ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
|
||||
const CPUDevice& d, const Index slice_size, \
|
||||
typename TTypes<Index>::Scalar Tscratch, \
|
||||
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::Tensor Tparams, \
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices, \
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates, \
|
||||
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::Tensor Toutput)
|
||||
|
||||
#define REGISTER_SCATTER_ND_INDEX(type, op) \
|
||||
REGISTER_SCATTER_ND_FULL(type, int32, op); \
|
||||
REGISTER_SCATTER_ND_FULL(type, int64, op)
|
||||
|
||||
#define REGISTER_SCATTER_ND_UPDATE(type) \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);
|
||||
|
||||
#define REGISTER_SCATTER_ND_MATH(type) \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MUL); \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::DIV);
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH)
|
||||
|
||||
#undef REGISTER_SCATTER_ND_MATH
|
||||
#undef REGISTER_SCATTER_ND_UPDATE
|
||||
#undef REGISTER_SCATTER_ND_INDEX
|
||||
#undef REGISTER_SCATTER_ND_FULL
|
||||
|
||||
} // namespace functor
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 0
|
||||
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 1
|
||||
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 2
|
||||
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 3
|
||||
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 4
|
||||
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
19
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
Normal file
19
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
Normal file
@ -0,0 +1,19 @@
|
||||
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 5
|
||||
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
320
tensorflow/core/kernels/scatter_nd_op_test.cc
Normal file
320
tensorflow/core/kernels/scatter_nd_op_test.cc
Normal file
@ -0,0 +1,320 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class ScatterNdUpdateOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType variable_ref_type, DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
|
||||
.Input(FakeInput(variable_ref_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Input(FakeInput(RemoveRefType(variable_ref_type)))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_StringType) {
|
||||
MakeOp(DT_STRING_REF, DT_INT32);
|
||||
AddInputFromArray<string>(TensorShape({1}), {"Brain"});
|
||||
AddInputFromArray<int32>(TensorShape({1}), {0});
|
||||
AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_STRING, TensorShape({1}));
|
||||
test::FillValues<string>(&expected, {"TensorFlow"});
|
||||
test::ExpectTensorEqual<string>(expected, params_tensor);
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) {
|
||||
MakeOp(DT_BOOL_REF, DT_INT32);
|
||||
AddInputFromArray<bool>(TensorShape({1}), {false});
|
||||
AddInputFromArray<int32>(TensorShape({1}), {0});
|
||||
AddInputFromArray<bool>(TensorShape({1}), {true});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
|
||||
test::FillValues<bool>(&expected, {true});
|
||||
test::ExpectTensorEqual<bool>(expected, params_tensor);
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
|
||||
test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
|
||||
10002, 0, 0, 0, 777, 778, 779});
|
||||
test::ExpectTensorEqual<float>(expected, params_tensor);
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_Two64) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT64);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int64>(TensorShape({3, 1}), {0, 4, 2});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
|
||||
test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
|
||||
10002, 0, 0, 0, 777, 778, 779});
|
||||
test::ExpectTensorEqual<float>(expected, params_tensor);
|
||||
}
|
||||
/*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({0}), {});
|
||||
AddInputFromArray<int32>(TensorShape({0}), {});
|
||||
AddInputFromArray<float>(TensorShape({0}), {});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(StringPiece(s.ToString())
|
||||
.contains("Output must not have 0 elements, got shape: "))
|
||||
<< s;
|
||||
}*/
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_ZeroD) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({1}), {3});
|
||||
AddInputFromArray<float>(TensorShape({1}), {101});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
|
||||
test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
|
||||
test::ExpectTensorEqual<float>(expected, params_tensor);
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_OneD) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
|
||||
AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
|
||||
test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
|
||||
test::ExpectTensorEqual<float>(expected, params_tensor);
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, HigherRank) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6});
|
||||
AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
|
||||
test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
|
||||
test::ExpectTensorEqual<float>(expected, params_tensor);
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 99});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(StringPiece(s.ToString())
|
||||
.contains("Invalid indices: [2,0] = [99] is not in [0, 5)"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({1, 3, 1}), {0, 4, 99});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(StringPiece(s.ToString())
|
||||
.contains("The outermost dimension of updates and indices "
|
||||
"must match. Got indices.shape [1,3,1], "
|
||||
"updates.shape [3,3]"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
|
||||
AddInputFromArray<float>(
|
||||
TensorShape({3, 4}),
|
||||
{100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(StringPiece(s.ToString())
|
||||
.contains("Must have updates.shape = indices.shape[0] + "
|
||||
"params_shape[IXDIM:], got"))
|
||||
|
||||
<< s;
|
||||
}
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
|
||||
AddInputFromArray<float>(TensorShape({2, 3}),
|
||||
{100, 101, 102, 10000, 10001, 10002});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(StringPiece(s.ToString())
|
||||
.contains("The outermost dimension of updates and indices "
|
||||
"must match. Got "))
|
||||
<< s;
|
||||
}
|
||||
|
||||
class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
|
||||
public:
|
||||
virtual void TestBody() {}
|
||||
void MakeBenchmarkOp(const char* op, DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", op)
|
||||
.Input(FakeInput(DT_FLOAT_REF))
|
||||
.Input(FakeInput(index_type))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Finalize(node_def()));
|
||||
TF_CHECK_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Index>
|
||||
static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
|
||||
testing::StopTiming();
|
||||
const int kRows = 10000000 / embedding_size;
|
||||
std::vector<float> values;
|
||||
values.reserve(kRows);
|
||||
for (int i = 0; i < kRows * embedding_size; i++) {
|
||||
values.push_back(i);
|
||||
}
|
||||
const int kNumUpdates = 1000;
|
||||
random::PhiloxRandom philox(301, 17);
|
||||
random::SimplePhilox rnd(&philox);
|
||||
std::vector<Index> indices;
|
||||
std::vector<float> updates;
|
||||
for (int i = 0; i < kNumUpdates; i++) {
|
||||
indices.push_back(rnd.Uniform(kRows));
|
||||
for (int j = 0; j < embedding_size; j++) {
|
||||
updates.push_back(i * 10 + j);
|
||||
}
|
||||
}
|
||||
|
||||
ScatterNdUpdateBM bm;
|
||||
bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
|
||||
bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
|
||||
bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
|
||||
bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
|
||||
updates);
|
||||
testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
|
||||
iters);
|
||||
testing::StartTiming();
|
||||
while (iters-- > 0) {
|
||||
Status s = bm.RunOpKernel();
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
|
||||
BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate");
|
||||
}
|
||||
static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
|
||||
BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate");
|
||||
}
|
||||
|
||||
static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
|
||||
BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd");
|
||||
}
|
||||
static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
|
||||
BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd");
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ScatterNdUpdateInt32)
|
||||
->Arg(1)
|
||||
->Arg(10)
|
||||
->Arg(64)
|
||||
->Arg(256)
|
||||
->Arg(1024);
|
||||
BENCHMARK(BM_ScatterNdUpdateInt64)
|
||||
->Arg(1)
|
||||
->Arg(10)
|
||||
->Arg(64)
|
||||
->Arg(256)
|
||||
->Arg(1024);
|
||||
|
||||
BENCHMARK(BM_ScatterNdAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
|
||||
BENCHMARK(BM_ScatterNdAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -88,16 +88,12 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
|
||||
|
||||
void ParallelFor(int64 total, int64 cost_per_unit,
|
||||
std::function<void(int64, int64)> fn) {
|
||||
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
|
||||
CHECK_GE(total, 0);
|
||||
CHECK_EQ(total, (int64)(Eigen::Index)total);
|
||||
Eigen::ThreadPoolDevice device(this, this->NumThreads());
|
||||
device.parallelFor(
|
||||
total, Eigen::TensorOpCost(0, 0, cost_per_unit),
|
||||
[&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
|
||||
#else
|
||||
CHECK(0); // should not be used with the old thread pool
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -57,7 +57,6 @@ TEST(ThreadPool, DoWork) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
|
||||
TEST(ThreadPool, ParallelFor) {
|
||||
// Make ParallelFor use as many threads as possible.
|
||||
int64 kHugeCost = 1 << 30;
|
||||
@ -80,7 +79,6 @@ TEST(ThreadPool, ParallelFor) {
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
static void BM_Sequential(int iters) {
|
||||
ThreadPool pool(Env::Default(), "test", kNumThreads);
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -49,11 +50,14 @@ string JoinPathImpl(std::initializer_list<StringPiece> paths) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Return the parts of the path, split on the final "/". If there is no
|
||||
// "/" in the path, the first part of the output is empty and the second
|
||||
// is the input. If the only "/" in the path is the first character, it is
|
||||
// the first part of the output.
|
||||
std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) {
|
||||
// Return the parts of the URI, split on the final "/" in the path. If there is
|
||||
// no "/" in the path, the first part of the output is the scheme and host, and
|
||||
// the second is the path. If the only "/" in the path is the first character,
|
||||
// it is included in the first part of the output.
|
||||
std::pair<StringPiece, StringPiece> SplitPath(StringPiece uri) {
|
||||
StringPiece scheme, host, path;
|
||||
ParseURI(uri, &scheme, &host, &path);
|
||||
|
||||
auto pos = path.rfind('/');
|
||||
#ifdef PLATFORM_WINDOWS
|
||||
if (pos == StringPiece::npos)
|
||||
@ -61,15 +65,17 @@ std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) {
|
||||
#endif
|
||||
// Handle the case with no '/' in 'path'.
|
||||
if (pos == StringPiece::npos)
|
||||
return std::make_pair(StringPiece(path.data(), 0), path);
|
||||
return std::make_pair(StringPiece(uri.begin(), host.end() - uri.begin()),
|
||||
path);
|
||||
|
||||
// Handle the case with a single leading '/' in 'path'.
|
||||
if (pos == 0)
|
||||
return std::make_pair(StringPiece(path.data(), 1),
|
||||
StringPiece(path.data() + 1, path.size() - 1));
|
||||
return std::make_pair(
|
||||
StringPiece(uri.begin(), path.begin() + 1 - uri.begin()),
|
||||
StringPiece(path.data() + 1, path.size() - 1));
|
||||
|
||||
return std::make_pair(
|
||||
StringPiece(path.data(), pos),
|
||||
StringPiece(uri.begin(), path.begin() + pos - uri.begin()),
|
||||
StringPiece(path.data() + pos + 1, path.size() - (pos + 1)));
|
||||
}
|
||||
|
||||
@ -185,5 +191,42 @@ string CleanPath(StringPiece unclean_path) {
|
||||
return path;
|
||||
}
|
||||
|
||||
void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
|
||||
StringPiece* path) {
|
||||
// 0. Parse scheme
|
||||
// Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]*
|
||||
// TODO(keveman): Allow "+" and "-" in the scheme.
|
||||
if (!strings::Scanner(remaining)
|
||||
.One(strings::Scanner::LETTER)
|
||||
.Many(strings::Scanner::LETTER_DIGIT_DOT)
|
||||
.StopCapture()
|
||||
.OneLiteral("://")
|
||||
.GetResult(&remaining, scheme)) {
|
||||
// If there's no scheme, assume the entire string is a path.
|
||||
*scheme = StringPiece(remaining.begin(), 0);
|
||||
*host = StringPiece(remaining.begin(), 0);
|
||||
*path = remaining;
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. Parse host
|
||||
if (!strings::Scanner(remaining).ScanUntil('/').GetResult(&remaining, host)) {
|
||||
// No path, so the rest of the URI is the host.
|
||||
*host = remaining;
|
||||
*path = StringPiece(remaining.end(), 0);
|
||||
return;
|
||||
}
|
||||
|
||||
// 2. The rest is the path
|
||||
*path = remaining;
|
||||
}
|
||||
|
||||
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) {
|
||||
if (scheme.empty()) {
|
||||
return path.ToString();
|
||||
}
|
||||
return strings::StrCat(scheme, "://", host, path);
|
||||
}
|
||||
|
||||
} // namespace io
|
||||
} // namespace tensorflow
|
||||
|
@ -74,6 +74,21 @@ StringPiece Extension(StringPiece path);
|
||||
// string manipulation, completely independent of process state.
|
||||
string CleanPath(StringPiece path);
|
||||
|
||||
// Populates the scheme, host, and path from a URI. scheme, host, and path are
|
||||
// guaranteed by this function to point into the contents of uri, even if
|
||||
// empty.
|
||||
//
|
||||
// Corner cases:
|
||||
// - If the URI is invalid, scheme and host are set to empty strings and the
|
||||
// passed string is assumed to be a path
|
||||
// - If the URI omits the path (e.g. file://host), then the path is left empty.
|
||||
void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host,
|
||||
StringPiece* path);
|
||||
|
||||
// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
|
||||
// return the path.
|
||||
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path);
|
||||
|
||||
} // namespace io
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -45,6 +45,8 @@ TEST(PathTest, IsAbsolutePath) {
|
||||
}
|
||||
|
||||
TEST(PathTest, Dirname) {
|
||||
EXPECT_EQ("hdfs://127.0.0.1:9000/",
|
||||
Dirname("hdfs://127.0.0.1:9000/train.csv.tfrecords"));
|
||||
EXPECT_EQ("/hello", Dirname("/hello/"));
|
||||
EXPECT_EQ("/", Dirname("/hello"));
|
||||
EXPECT_EQ("hello", Dirname("hello/world"));
|
||||
@ -97,5 +99,47 @@ TEST(PathTest, CleanPath) {
|
||||
EXPECT_EQ("../../bar", CleanPath("foo/../../../bar"));
|
||||
}
|
||||
|
||||
#define EXPECT_PARSE_URI(uri, scheme, host, path) \
|
||||
do { \
|
||||
StringPiece u(uri); \
|
||||
StringPiece s, h, p; \
|
||||
ParseURI(u, &s, &h, &p); \
|
||||
EXPECT_EQ(scheme, s.ToString()); \
|
||||
EXPECT_EQ(host, h.ToString()); \
|
||||
EXPECT_EQ(path, p.ToString()); \
|
||||
EXPECT_EQ(uri, CreateURI(scheme, host, path)); \
|
||||
EXPECT_LE(u.begin(), s.begin()); \
|
||||
EXPECT_GE(u.end(), s.begin()); \
|
||||
EXPECT_LE(u.begin(), s.end()); \
|
||||
EXPECT_GE(u.end(), s.end()); \
|
||||
EXPECT_LE(u.begin(), h.begin()); \
|
||||
EXPECT_GE(u.end(), h.begin()); \
|
||||
EXPECT_LE(u.begin(), h.end()); \
|
||||
EXPECT_GE(u.end(), h.end()); \
|
||||
EXPECT_LE(u.begin(), p.begin()); \
|
||||
EXPECT_GE(u.end(), p.begin()); \
|
||||
EXPECT_LE(u.begin(), p.end()); \
|
||||
EXPECT_GE(u.end(), p.end()); \
|
||||
} while (0)
|
||||
|
||||
TEST(PathTest, CreateParseURI) {
|
||||
EXPECT_PARSE_URI("http://foo", "http", "foo", "");
|
||||
EXPECT_PARSE_URI("/encrypted/://foo", "", "", "/encrypted/://foo");
|
||||
EXPECT_PARSE_URI("/usr/local/foo", "", "", "/usr/local/foo");
|
||||
EXPECT_PARSE_URI("file:///usr/local/foo", "file", "", "/usr/local/foo");
|
||||
EXPECT_PARSE_URI("local.file:///usr/local/foo", "local.file", "",
|
||||
"/usr/local/foo");
|
||||
EXPECT_PARSE_URI("a-b:///foo", "", "", "a-b:///foo");
|
||||
EXPECT_PARSE_URI(":///foo", "", "", ":///foo");
|
||||
EXPECT_PARSE_URI("9dfd:///foo", "", "", "9dfd:///foo");
|
||||
EXPECT_PARSE_URI("file:", "", "", "file:");
|
||||
EXPECT_PARSE_URI("file:/", "", "", "file:/");
|
||||
EXPECT_PARSE_URI("hdfs://localhost:8020/path/to/file", "hdfs",
|
||||
"localhost:8020", "/path/to/file");
|
||||
EXPECT_PARSE_URI("hdfs://localhost:8020", "hdfs", "localhost:8020", "");
|
||||
EXPECT_PARSE_URI("hdfs://localhost:8020/", "hdfs", "localhost:8020", "/");
|
||||
}
|
||||
#undef EXPECT_PARSE_URI
|
||||
|
||||
} // namespace io
|
||||
} // namespace tensorflow
|
||||
|
@ -4387,6 +4387,83 @@ output_min: This value is copied from input_min.
|
||||
output_max: This value is copied from input_max.
|
||||
)Doc");
|
||||
|
||||
REGISTER_OP("ScatterNd")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Input("shape: Tindices")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Doc(
|
||||
R"doc(Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices.
|
||||
This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor.
|
||||
|
||||
TODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax.
|
||||
|
||||
`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `shape`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `shape`.
|
||||
|
||||
`updates` is Tensor of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]].
|
||||
```
|
||||
|
||||
The simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../../images/ScatterNd1.png" alt>
|
||||
</div>
|
||||
|
||||
In Python, this scatter operation would look like this:
|
||||
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
updates = tf.constant([9, 10, 11, 12])
|
||||
shape = tf.constant([8])
|
||||
scatter = tf.scatter_nd(indices, updates, shape)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(scatter)
|
||||
|
||||
The resulting tensor would look like this:
|
||||
|
||||
[0, 11, 0, 10, 9, 0, 0, 12]
|
||||
|
||||
We can also, insert entire slices of a higher rank tensor all at once. For example, if we wanted to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../../images/ScatterNd2.png" alt>
|
||||
</div>
|
||||
|
||||
In Python, this scatter operation would look like this:
|
||||
|
||||
indices = tf.constant([[0], [2]])
|
||||
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]]])
|
||||
shape = tf.constant([4, 4, 4])
|
||||
scatter = tf.scatter_nd(indices, updates, shape)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(scatter)
|
||||
|
||||
The resulting tensor would look like this:
|
||||
|
||||
[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
|
||||
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as tensor. A tensor of updated values to store in ref.
|
||||
shape: A vector. The shape of the resulting tensor.
|
||||
output: A new tensor with the given shape and updates applied according to the indices.)doc");
|
||||
|
||||
REGISTER_OP("FakeQuantWithMinMaxArgs")
|
||||
.Attr("min: float = -6.0")
|
||||
.Attr("max: float = 6.0")
|
||||
@ -4409,6 +4486,7 @@ REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
|
||||
.Input("gradients: float")
|
||||
.Input("inputs: float")
|
||||
.Output("backprops: float")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Compute gradients for a FakeQuantWithMinMaxArgs operation.
|
||||
|
||||
@ -4450,6 +4528,21 @@ REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
|
||||
.Output("backprops_wrt_input: float")
|
||||
.Output("backprop_wrt_min: float")
|
||||
.Output("backprop_wrt_max: float")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// gradients and inputs are same size.
|
||||
ShapeHandle inputs;
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
|
||||
|
||||
// min and max are scalars
|
||||
ShapeHandle min_max;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
|
||||
TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));
|
||||
|
||||
c->set_output(0, inputs);
|
||||
c->set_output(1, min_max);
|
||||
c->set_output(2, min_max);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Compute gradients for a FakeQuantWithMinMaxVars operation.
|
||||
|
||||
@ -4503,6 +4596,24 @@ REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
|
||||
.Output("backprops_wrt_input: float")
|
||||
.Output("backprop_wrt_min: float")
|
||||
.Output("backprop_wrt_max: float")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle inputs;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
|
||||
TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));
|
||||
|
||||
ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
|
||||
|
||||
ShapeHandle min_max;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
|
||||
TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));
|
||||
|
||||
c->set_output(0, inputs);
|
||||
c->set_output(1, min_max);
|
||||
c->set_output(2, min_max);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.
|
||||
|
||||
|
@ -1533,4 +1533,23 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
|
||||
INFER_ERROR("must be equal", op, "[5];[4];[?]");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
|
||||
ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");
|
||||
|
||||
INFER_OK(op, "?;?;?;?", "?;[?];[?]");
|
||||
INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
|
||||
INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
|
||||
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");
|
||||
|
||||
// Rank check vectors.
|
||||
INFER_ERROR("be equal rank", op, "[1,?,3];[1,?,3];[3];[]");
|
||||
INFER_ERROR("be rank 1", op, "[1,?,3];[1,?,3];[];[3]");
|
||||
INFER_ERROR("be at least rank 1", op, "[];[];[1];[1]");
|
||||
INFER_ERROR("be at most rank 4", op, "[1,2,3,4,5];[1,2,3,4,5];[1];[1]");
|
||||
|
||||
// Vectors must match each other, and match last dim of input.
|
||||
INFER_ERROR("must be equal", op, "[1,3];[1,3];[2];[3]");
|
||||
INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -24732,6 +24732,321 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNd"
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "shape"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdAdd"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdDiv"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdMul"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdSub"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdUpdate"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterSub"
|
||||
input_arg {
|
||||
|
@ -15427,6 +15427,362 @@ op {
|
||||
summary: "Multiplies sparse updates into a variable reference."
|
||||
description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] *= updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] *= updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nDuplicate entries are handled correctly: if multiple `indices` reference\nthe same location, their contributions multiply.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`."
|
||||
}
|
||||
op {
|
||||
name: "ScatterNd"
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
description: "A Tensor. Must have the same type as tensor. A tensor of updated values to store in ref."
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "shape"
|
||||
description: "A vector. The shape of the resulting tensor."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "A new tensor with the given shape and updates applied according to the indices."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices."
|
||||
description: "This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor.\n\nTODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax.\n\n`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `shape`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `shape`.\n\n`updates` is Tensor of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]].\n```\n\nThe simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterNd1.png\" alt>\n</div>\n\nIn Python, this scatter operation would look like this:\n\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n shape = tf.constant([8])\n scatter = tf.scatter_nd(indices, updates, shape)\n with tf.Session() as sess:\n print sess.run(scatter)\n\nThe resulting tensor would look like this:\n\n [0, 11, 0, 10, 9, 0, 0, 12]\n\nWe can also, insert entire slices of a higher rank tensor all at once. For example, if we wanted to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterNd2.png\" alt>\n</div>\n\nIn Python, this scatter operation would look like this:\n\n indices = tf.constant([[0], [2]])\n updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],\n [7, 7, 7, 7], [8, 8, 8, 8]],\n [[5, 5, 5, 5], [6, 6, 6, 6],\n [7, 7, 7, 7], [8, 8, 8, 8]]])\n shape = tf.constant([4, 4, 4])\n scatter = tf.scatter_nd(indices, updates, shape)\n with tf.Session() as sess:\n print sess.run(scatter)\n\nThe resulting tensor would look like this:\n\n [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],\n [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]"
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdAdd"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
description: "A mutable Tensor. Should be from a Variable node."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
description: "A Tensor. Must have the same type as ref. A tensor of updated values to add to ref."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||
}
|
||||
summary: "Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`."
|
||||
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to add 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that addition would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n add = tf.scatter_nd_add(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(add)\n\nThe resulting update to ref would look like this:\n\n [1, 13, 3, 14, 14, 6, 7, 20]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdDiv"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
description: "A mutable Tensor. Should be from a Variable node."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||
}
|
||||
summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`."
|
||||
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:\n\n ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([2, 3, 4, 5])\n sub = tf.scatter_nd_div(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [10, 5, 30, 13, 25, 60, 70, 16]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdMul"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
description: "A mutable Tensor. Should be from a Variable node."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||
}
|
||||
summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`."
|
||||
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n sub = tf.scatter_nd_mul(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [1, 22, 3, 40, 45, 6, 7, 96]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdSub"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
description: "A mutable Tensor. Should be from a Variable node."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||
}
|
||||
summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`."
|
||||
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n sub = tf.scatter_nd_sub(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [1, -9, 3, -6, -4, 6, 7, -4]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdUpdate"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
description: "A mutable Tensor. Should be from a Variable node."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
description: "A Tensor. Must have the same type as ref. A tensor of updated values to add to ref."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||
}
|
||||
summary: "Applies sparse `updates` to individual values or slices within a given variable according to `indices`."
|
||||
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1] ,[7]])\n updates = tf.constant([9, 10, 11, 12])\n update = tf.scatter_nd_update(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(update)\n\nThe resulting update to ref would look like this:\n\n [1, 11, 3, 10, 9, 6, 7, 12]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||
}
|
||||
op {
|
||||
name: "ScatterSub"
|
||||
input_arg {
|
||||
|
@ -445,6 +445,241 @@ use_locking: If True, the operation will be protected by a lock;
|
||||
otherwise the behavior is undefined, but may exhibit less contention.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ScatterNdUpdate")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.Doc(
|
||||
R"doc(Applies sparse `updates` to individual values or slices within a given variable according to `indices`.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
For example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this:
|
||||
|
||||
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
indices = tf.constant([[4], [3], [1] ,[7]])
|
||||
updates = tf.constant([9, 10, 11, 12])
|
||||
update = tf.scatter_nd_update(ref, indices, updates)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(update)
|
||||
|
||||
The resulting update to ref would look like this:
|
||||
|
||||
[1, 11, 3, 10, 9, 6, 7, 12]
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
|
||||
REGISTER_OP("ScatterNdAdd")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(
|
||||
R"doc(Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that addition would look like this:
|
||||
|
||||
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
updates = tf.constant([9, 10, 11, 12])
|
||||
add = tf.scatter_nd_add(ref, indices, updates)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(add)
|
||||
|
||||
The resulting update to ref would look like this:
|
||||
|
||||
[1, 13, 3, 14, 14, 6, 7, 20]
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
|
||||
REGISTER_OP("ScatterNdSub")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(
|
||||
R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
For example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this:
|
||||
|
||||
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
updates = tf.constant([9, 10, 11, 12])
|
||||
sub = tf.scatter_nd_sub(ref, indices, updates)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(sub)
|
||||
|
||||
The resulting update to ref would look like this:
|
||||
|
||||
[1, -9, 3, -6, -4, 6, 7, -4]
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
|
||||
REGISTER_OP("ScatterNdMul")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(
|
||||
R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
For example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:
|
||||
|
||||
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
updates = tf.constant([9, 10, 11, 12])
|
||||
sub = tf.scatter_nd_mul(ref, indices, updates)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(sub)
|
||||
|
||||
The resulting update to ref would look like this:
|
||||
|
||||
[1, 22, 3, 40, 45, 6, 7, 96]
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
|
||||
REGISTER_OP("ScatterNdDiv")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(
|
||||
R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
For example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:
|
||||
|
||||
ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
updates = tf.constant([2, 3, 4, 5])
|
||||
sub = tf.scatter_nd_div(ref, indices, updates)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(sub)
|
||||
|
||||
The resulting update to ref would look like this:
|
||||
|
||||
[10, 5, 30, 13, 25, 60, 70, 16]
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
|
||||
REGISTER_OP("CountUpTo")
|
||||
.Input("ref: Ref(T)")
|
||||
.Output("output: T")
|
||||
|
@ -81,7 +81,7 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
|
||||
return errors::Internal("bucket and object cannot be null.");
|
||||
}
|
||||
StringPiece scheme, bucketp, objectp;
|
||||
ParseURI(fname, &scheme, &bucketp, &objectp);
|
||||
io::ParseURI(fname, &scheme, &bucketp, &objectp);
|
||||
if (scheme != "gs") {
|
||||
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
|
||||
fname);
|
||||
|
@ -76,16 +76,24 @@ cc_library(
|
||||
name = "platformlib",
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":gif",
|
||||
":jpeg",
|
||||
"//tensorflow/core:protos_cc",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
"@farmhash_archive//:farmhash",
|
||||
"@gif_archive//:gif",
|
||||
"@highwayhash//:sip_hash",
|
||||
"@jpeg_archive//:jpeg",
|
||||
"@png_archive//:png",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gif",
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"@gif_archive//:gif",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jpeg",
|
||||
copts = tf_copts(),
|
||||
|
@ -70,7 +70,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {}
|
||||
|
||||
Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) {
|
||||
StringPiece scheme, host, path;
|
||||
ParseURI(fname, &scheme, &host, &path);
|
||||
io::ParseURI(fname, &scheme, &host, &path);
|
||||
FileSystem* file_system = file_system_registry_->Lookup(scheme.ToString());
|
||||
if (!file_system) {
|
||||
return errors::Unimplemented("File system scheme ", scheme,
|
||||
|
@ -229,35 +229,6 @@ TEST_F(DefaultEnvTest, LocalFileSystem) {
|
||||
}
|
||||
}
|
||||
|
||||
#define EXPECT_PARSE_URI(uri, scheme, host, path) \
|
||||
do { \
|
||||
StringPiece s, h, p; \
|
||||
ParseURI(uri, &s, &h, &p); \
|
||||
EXPECT_EQ(scheme, s.ToString()); \
|
||||
EXPECT_EQ(host, h.ToString()); \
|
||||
EXPECT_EQ(path, p.ToString()); \
|
||||
EXPECT_EQ(uri, CreateURI(scheme, host, path)); \
|
||||
} while (0)
|
||||
|
||||
TEST_F(DefaultEnvTest, CreateParseURI) {
|
||||
EXPECT_PARSE_URI("http://foo", "http", "foo", "");
|
||||
EXPECT_PARSE_URI("/encrypted/://foo", "", "", "/encrypted/://foo");
|
||||
EXPECT_PARSE_URI("/usr/local/foo", "", "", "/usr/local/foo");
|
||||
EXPECT_PARSE_URI("file:///usr/local/foo", "file", "", "/usr/local/foo");
|
||||
EXPECT_PARSE_URI("local.file:///usr/local/foo", "local.file", "",
|
||||
"/usr/local/foo");
|
||||
EXPECT_PARSE_URI("a-b:///foo", "", "", "a-b:///foo");
|
||||
EXPECT_PARSE_URI(":///foo", "", "", ":///foo");
|
||||
EXPECT_PARSE_URI("9dfd:///foo", "", "", "9dfd:///foo");
|
||||
EXPECT_PARSE_URI("file:", "", "", "file:");
|
||||
EXPECT_PARSE_URI("file:/", "", "", "file:/");
|
||||
EXPECT_PARSE_URI("hdfs://localhost:8020/path/to/file", "hdfs",
|
||||
"localhost:8020", "/path/to/file");
|
||||
EXPECT_PARSE_URI("hdfs://localhost:8020", "hdfs", "localhost:8020", "");
|
||||
EXPECT_PARSE_URI("hdfs://localhost:8020/", "hdfs", "localhost:8020", "/");
|
||||
}
|
||||
#undef EXPECT_PARSE_URI
|
||||
|
||||
TEST_F(DefaultEnvTest, SleepForMicroseconds) {
|
||||
const int64 start = env_->NowMicros();
|
||||
const int64 sleep_time = 1e6 + 5e5;
|
||||
@ -274,14 +245,14 @@ class TmpDirFileSystem : public NullFileSystem {
|
||||
public:
|
||||
bool FileExists(const string& dir) override {
|
||||
StringPiece scheme, host, path;
|
||||
ParseURI(dir, &scheme, &host, &path);
|
||||
io::ParseURI(dir, &scheme, &host, &path);
|
||||
if (path.empty()) return false;
|
||||
return Env::Default()->FileExists(io::JoinPath(BaseDir(), path));
|
||||
}
|
||||
|
||||
Status CreateDir(const string& dir) override {
|
||||
StringPiece scheme, host, path;
|
||||
ParseURI(dir, &scheme, &host, &path);
|
||||
io::ParseURI(dir, &scheme, &host, &path);
|
||||
if (scheme != "tmpdirfs") {
|
||||
return errors::FailedPrecondition("scheme must be tmpdirfs");
|
||||
}
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
@ -79,43 +78,6 @@ WritableFile::~WritableFile() {}
|
||||
|
||||
FileSystemRegistry::~FileSystemRegistry() {}
|
||||
|
||||
void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
|
||||
StringPiece* path) {
|
||||
// 0. Parse scheme
|
||||
// Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]*
|
||||
// TODO(keveman): Allow "+" and "-" in the scheme.
|
||||
if (!strings::Scanner(remaining)
|
||||
.One(strings::Scanner::LETTER)
|
||||
.Many(strings::Scanner::LETTER_DIGIT_DOT)
|
||||
.StopCapture()
|
||||
.OneLiteral("://")
|
||||
.GetResult(&remaining, scheme)) {
|
||||
// If there's no scheme, assume the entire string is a path.
|
||||
scheme->clear();
|
||||
host->clear();
|
||||
*path = remaining;
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. Parse host
|
||||
if (!strings::Scanner(remaining).ScanUntil('/').GetResult(&remaining, host)) {
|
||||
// No path, so the rest of the URI is the host.
|
||||
*host = remaining;
|
||||
path->clear();
|
||||
return;
|
||||
}
|
||||
|
||||
// 2. The rest is the path
|
||||
*path = remaining;
|
||||
}
|
||||
|
||||
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) {
|
||||
if (scheme.empty()) {
|
||||
return path.ToString();
|
||||
}
|
||||
return strings::StrCat(scheme, "://", host, path);
|
||||
}
|
||||
|
||||
Status FileSystem::GetMatchingPaths(const string& pattern,
|
||||
std::vector<string>* results) {
|
||||
results->clear();
|
||||
@ -237,9 +199,9 @@ Status FileSystem::DeleteRecursively(const string& dirname,
|
||||
|
||||
Status FileSystem::RecursivelyCreateDir(const string& dirname) {
|
||||
StringPiece scheme, host, remaining_dir;
|
||||
ParseURI(dirname, &scheme, &host, &remaining_dir);
|
||||
io::ParseURI(dirname, &scheme, &host, &remaining_dir);
|
||||
std::vector<StringPiece> sub_dirs;
|
||||
while (!FileExists(CreateURI(scheme, host, remaining_dir)) &&
|
||||
while (!FileExists(io::CreateURI(scheme, host, remaining_dir)) &&
|
||||
!remaining_dir.empty()) {
|
||||
// Basename returns "" for / ending dirs.
|
||||
if (!remaining_dir.ends_with("/")) {
|
||||
@ -255,7 +217,7 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) {
|
||||
string built_path = remaining_dir.ToString();
|
||||
for (const StringPiece sub_dir : sub_dirs) {
|
||||
built_path = io::JoinPath(built_path, sub_dir);
|
||||
TF_RETURN_IF_ERROR(CreateDir(CreateURI(scheme, host, built_path)));
|
||||
TF_RETURN_IF_ERROR(CreateDir(io::CreateURI(scheme, host, built_path)));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -287,19 +287,6 @@ class FileSystemRegistry {
|
||||
std::vector<string>* schemes) = 0;
|
||||
};
|
||||
|
||||
// Populates the scheme, host, and path from a URI.
|
||||
//
|
||||
// Corner cases:
|
||||
// - If the URI is invalid, scheme and host are set to empty strings and the
|
||||
// passed string is assumed to be a path
|
||||
// - If the URI omits the path (e.g. file://host), then the path is left empty.
|
||||
void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host,
|
||||
StringPiece* path);
|
||||
|
||||
// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
|
||||
// return the path.
|
||||
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_H_
|
||||
|
@ -112,7 +112,7 @@ class InterPlanetaryFileSystem : public NullFileSystem {
|
||||
|
||||
void ParsePath(const string& name, string* parsed_path) {
|
||||
StringPiece scheme, host, path;
|
||||
ParseURI(name, &scheme, &host, &path);
|
||||
io::ParseURI(name, &scheme, &host, &path);
|
||||
ASSERT_EQ(scheme, "ipfs");
|
||||
ASSERT_EQ(host, "solarsystem");
|
||||
path.Consume("/");
|
||||
|
@ -126,7 +126,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
|
||||
TF_RETURN_IF_ERROR(hdfs_->status());
|
||||
|
||||
StringPiece scheme, namenode, path;
|
||||
ParseURI(fname, &scheme, &namenode, &path);
|
||||
io::ParseURI(fname, &scheme, &namenode, &path);
|
||||
const string nn = namenode.ToString();
|
||||
|
||||
hdfsBuilder* builder = hdfs_->hdfsNewBuilder();
|
||||
@ -144,7 +144,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
|
||||
|
||||
string HadoopFileSystem::TranslateName(const string& name) const {
|
||||
StringPiece scheme, namenode, path;
|
||||
ParseURI(name, &scheme, &namenode, &path);
|
||||
io::ParseURI(name, &scheme, &namenode, &path);
|
||||
return path.ToString();
|
||||
}
|
||||
|
||||
|
@ -120,7 +120,8 @@ class PosixEnv : public Env {
|
||||
symbol);
|
||||
}
|
||||
|
||||
string FormatLibraryFileName(const string& name, const string& version) {
|
||||
string FormatLibraryFileName(const string& name,
|
||||
const string& version) override {
|
||||
return tensorflow::internal::FormatLibraryFileName(name, version);
|
||||
}
|
||||
};
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
|
||||
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -63,7 +64,7 @@ class LocalPosixFileSystem : public PosixFileSystem {
|
||||
public:
|
||||
string TranslateName(const string& name) const override {
|
||||
StringPiece scheme, host, path;
|
||||
ParseURI(name, &scheme, &host, &path);
|
||||
io::ParseURI(name, &scheme, &host, &path);
|
||||
return path.ToString();
|
||||
}
|
||||
};
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
|
||||
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/file_system.h"
|
||||
|
||||
#ifdef PLATFORM_WINDOWS
|
||||
@ -68,7 +69,7 @@ class LocalWinFileSystem : public WindowsFileSystem {
|
||||
public:
|
||||
string TranslateName(const string& name) const override {
|
||||
StringPiece scheme, host, path;
|
||||
ParseURI(name, &scheme, &host, &path);
|
||||
io::ParseURI(name, &scheme, &host, &path);
|
||||
return path.ToString();
|
||||
}
|
||||
};
|
||||
|
@ -122,6 +122,10 @@ message RunStepRequest {
|
||||
|
||||
// Options for the run call.
|
||||
RunOptions options = 5;
|
||||
|
||||
// Partial run handle (optional). If specified, this will be a partial run
|
||||
// execution, run up to the specified fetches.
|
||||
string partial_run_handle = 6;
|
||||
}
|
||||
|
||||
message RunStepResponse {
|
||||
@ -133,6 +137,42 @@ message RunStepResponse {
|
||||
RunMetadata metadata = 2;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// PartialRunSetup method request/response protos.
|
||||
//
|
||||
// The caller should provide the future partial run feeds, fetches, and targets.
|
||||
// Then the caller can use RunStepRequest with is_partial set to make partial
|
||||
// run calls.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message PartialRunSetupRequest {
|
||||
// REQUIRED: session_handle must be returned by a CreateSession call
|
||||
// to the same master service.
|
||||
string session_handle = 1;
|
||||
|
||||
// Tensors to be fed in future steps.
|
||||
repeated string feed = 2;
|
||||
|
||||
// Fetches. A list of tensor names. The caller expects a tensor to be returned
|
||||
// for each fetch[i] (see RunStepResponse.tensor), for corresponding partial
|
||||
// RunStepRequests. The order of specified fetches does not change the
|
||||
// execution order.
|
||||
repeated string fetch = 3;
|
||||
|
||||
// Target Nodes. A list of node names. The named nodes will be run in future
|
||||
// steps, but their outputs will not be fetched.
|
||||
repeated string target = 4;
|
||||
}
|
||||
|
||||
message PartialRunSetupResponse {
|
||||
// The unique handle corresponding to the ongoing partial run call setup by
|
||||
// the invocation to PartialRunSetup. This handle may be passed to
|
||||
// RunStepRequest to send and receive tensors for this partial run.
|
||||
string partial_run_handle = 1;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CloseSession method request/response protos.
|
||||
|
@ -91,6 +91,9 @@ service MasterService {
|
||||
// Extends a session.
|
||||
rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse);
|
||||
|
||||
// Prepares future partial run calls.
|
||||
rpc PartialRunSetup(PartialRunSetupRequest) returns (PartialRunSetupResponse);
|
||||
|
||||
// Drives the graph computation.
|
||||
rpc RunStep(RunStepRequest) returns (RunStepResponse);
|
||||
|
||||
|
@ -343,7 +343,11 @@ Status BundleWriter::Finish() {
|
||||
status_ = env_->NewWritableFile(MetaFilename(prefix_), &file);
|
||||
if (!status_.ok()) return status_;
|
||||
{
|
||||
table::TableBuilder builder(table::Options(), file.get());
|
||||
// N.B.: the default use of Snappy compression may not be supported on all
|
||||
// platforms (e.g. Android). The metadata file is small, so this is fine.
|
||||
table::Options options;
|
||||
options.compression = table::kNoCompression;
|
||||
table::TableBuilder builder(options, file.get());
|
||||
// Header entry.
|
||||
BundleHeaderProto header;
|
||||
header.set_num_shards(1);
|
||||
|
@ -31,12 +31,10 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
|
||||
work(0, total);
|
||||
return;
|
||||
}
|
||||
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
|
||||
if (max_parallelism >= workers->NumThreads()) {
|
||||
workers->ParallelFor(total, cost_per_unit, work);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
cost_per_unit = std::max(1LL, cost_per_unit);
|
||||
// We shard [0, total) into "num_shards" shards.
|
||||
// 1 <= num_shards <= num worker threads
|
||||
|
@ -1,34 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!--
|
||||
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent">
|
||||
|
||||
<org.tensorflow.demo.AutoFitTextureView
|
||||
android:id="@+id/texture"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_alignParentBottom="true"
|
||||
android:layout_alignParentStart="true"
|
||||
android:layout_alignParentTop="true" />
|
||||
|
||||
<org.tensorflow.demo.RecognitionScoreView
|
||||
android:id="@+id/results"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="112dp"
|
||||
android:layout_alignParentTop="true" />
|
||||
|
||||
</RelativeLayout>
|
@ -22,11 +22,17 @@
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_alignParentBottom="true" />
|
||||
|
||||
|
||||
<org.tensorflow.demo.RecognitionScoreView
|
||||
android:id="@+id/results"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="112dp"
|
||||
android:layout_alignParentTop="true" />
|
||||
|
||||
|
||||
<org.tensorflow.demo.OverlayView
|
||||
android:id="@+id/overlay"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
android:layout_alignParentBottom="true" />
|
||||
|
||||
</RelativeLayout>
|
||||
|
@ -20,32 +20,72 @@ import android.Manifest;
|
||||
import android.app.Activity;
|
||||
import android.app.Fragment;
|
||||
import android.content.pm.PackageManager;
|
||||
import android.media.Image.Plane;
|
||||
import android.media.ImageReader.OnImageAvailableListener;
|
||||
import android.os.Build;
|
||||
import android.os.Bundle;
|
||||
import android.os.Handler;
|
||||
import android.os.HandlerThread;
|
||||
import android.util.Size;
|
||||
import android.view.MotionEvent;
|
||||
import android.view.WindowManager;
|
||||
import android.widget.Toast;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.tensorflow.demo.env.Logger;
|
||||
|
||||
public abstract class CameraActivity extends Activity implements OnImageAvailableListener {
|
||||
private static final Logger LOGGER = new Logger();
|
||||
|
||||
public abstract class CameraActivity extends Activity {
|
||||
private static final int PERMISSIONS_REQUEST = 1;
|
||||
|
||||
private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA;
|
||||
private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE;
|
||||
|
||||
private boolean debug = false;
|
||||
|
||||
private Handler handler;
|
||||
private HandlerThread handlerThread;
|
||||
|
||||
@Override
|
||||
protected void onCreate(final Bundle savedInstanceState) {
|
||||
super.onCreate(savedInstanceState);
|
||||
super.onCreate(null);
|
||||
getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
|
||||
|
||||
setContentView(R.layout.activity_camera);
|
||||
|
||||
if (hasPermission()) {
|
||||
if (null == savedInstanceState) {
|
||||
setFragment();
|
||||
}
|
||||
setFragment();
|
||||
} else {
|
||||
requestPermission();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void onResume() {
|
||||
super.onResume();
|
||||
|
||||
handlerThread = new HandlerThread("inference");
|
||||
handlerThread.start();
|
||||
handler = new Handler(handlerThread.getLooper());
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void onPause() {
|
||||
super.onPause();
|
||||
handlerThread.quitSafely();
|
||||
try {
|
||||
handlerThread.join();
|
||||
handlerThread = null;
|
||||
handler = null;
|
||||
} catch (final InterruptedException e) {
|
||||
LOGGER.e(e, "Exception!");
|
||||
}
|
||||
}
|
||||
|
||||
protected synchronized void runInBackground(final Runnable r) {
|
||||
if (handler != null) {
|
||||
handler.post(r);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -82,11 +122,47 @@ public abstract class CameraActivity extends Activity {
|
||||
}
|
||||
|
||||
protected void setFragment() {
|
||||
final Fragment fragment = CameraConnectionFragment.newInstance(
|
||||
new CameraConnectionFragment.ConnectionCallback(){
|
||||
@Override
|
||||
public void onPreviewSizeChosen(final Size size, final int rotation) {
|
||||
CameraActivity.this.onPreviewSizeChosen(size, rotation);
|
||||
}
|
||||
},
|
||||
this, getLayoutId(), getDesiredPreviewFrameSize());
|
||||
|
||||
getFragmentManager()
|
||||
.beginTransaction()
|
||||
.replace(R.id.container, createFragment())
|
||||
.replace(R.id.container, fragment)
|
||||
.commit();
|
||||
}
|
||||
|
||||
protected abstract Fragment createFragment();
|
||||
protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) {
|
||||
// Because of the variable row stride it's not possible to know in
|
||||
// advance the actual necessary dimensions of the yuv planes.
|
||||
for (int i = 0; i < planes.length; ++i) {
|
||||
final ByteBuffer buffer = planes[i].getBuffer();
|
||||
if (yuvBytes[i] == null) {
|
||||
LOGGER.i("Initializing buffer %d at size %d", i, buffer.capacity());
|
||||
yuvBytes[i] = new byte[buffer.capacity()];
|
||||
}
|
||||
buffer.get(yuvBytes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean onTouchEvent(final MotionEvent event) {
|
||||
if (event.getAction() == MotionEvent.ACTION_DOWN) {
|
||||
debug = !debug;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public boolean isDebug() {
|
||||
return debug;
|
||||
}
|
||||
|
||||
protected abstract void onPreviewSizeChosen(final Size size, final int rotation);
|
||||
protected abstract int getLayoutId();
|
||||
protected abstract int getDesiredPreviewFrameSize();
|
||||
}
|
||||
|
@ -38,6 +38,7 @@ import android.hardware.camera2.CaptureResult;
|
||||
import android.hardware.camera2.TotalCaptureResult;
|
||||
import android.hardware.camera2.params.StreamConfigurationMap;
|
||||
import android.media.ImageReader;
|
||||
import android.media.ImageReader.OnImageAvailableListener;
|
||||
import android.os.Bundle;
|
||||
import android.os.Handler;
|
||||
import android.os.HandlerThread;
|
||||
@ -49,9 +50,6 @@ import android.view.TextureView;
|
||||
import android.view.View;
|
||||
import android.view.ViewGroup;
|
||||
import android.widget.Toast;
|
||||
|
||||
import org.tensorflow.demo.env.Logger;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
@ -59,6 +57,7 @@ import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Semaphore;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.tensorflow.demo.env.Logger;
|
||||
|
||||
public class CameraConnectionFragment extends Fragment {
|
||||
private static final Logger LOGGER = new Logger();
|
||||
@ -69,8 +68,6 @@ public class CameraConnectionFragment extends Fragment {
|
||||
*/
|
||||
private static final int MINIMUM_PREVIEW_SIZE = 320;
|
||||
|
||||
private ResultsView resultsView;
|
||||
|
||||
/**
|
||||
* Conversion from screen rotation to JPEG orientation.
|
||||
*/
|
||||
@ -111,6 +108,14 @@ public class CameraConnectionFragment extends Fragment {
|
||||
public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
|
||||
};
|
||||
|
||||
/**
|
||||
* Callback for Activities to use to initialize their data once the
|
||||
* selected preview size is known.
|
||||
*/
|
||||
public interface ConnectionCallback {
|
||||
void onPreviewSizeChosen(Size size, int cameraRotation);
|
||||
}
|
||||
|
||||
/**
|
||||
* ID of the current {@link CameraDevice}.
|
||||
*/
|
||||
@ -184,16 +189,6 @@ public class CameraConnectionFragment extends Fragment {
|
||||
*/
|
||||
private Handler backgroundHandler;
|
||||
|
||||
/**
|
||||
* An additional thread for running inference so as not to block the camera.
|
||||
*/
|
||||
private HandlerThread inferenceThread;
|
||||
|
||||
/**
|
||||
* A {@link Handler} for running tasks in the background.
|
||||
*/
|
||||
private Handler inferenceHandler;
|
||||
|
||||
/**
|
||||
* An {@link ImageReader} that handles preview frame capture.
|
||||
*/
|
||||
@ -215,9 +210,10 @@ public class CameraConnectionFragment extends Fragment {
|
||||
private final Semaphore cameraOpenCloseLock = new Semaphore(1);
|
||||
|
||||
/**
|
||||
* A {@link Classifier} object wrapping TensorFlow to pass frames to.
|
||||
* A {@link OnImageAvailableListener} to receive frames as they are available.
|
||||
*/
|
||||
private final Classifier classifier;
|
||||
private final OnImageAvailableListener imageListener;
|
||||
|
||||
/**
|
||||
* The input size in pixels desired by TensorFlow (width and height of a square bitmap).
|
||||
*/
|
||||
@ -228,9 +224,15 @@ public class CameraConnectionFragment extends Fragment {
|
||||
*/
|
||||
private final int layout;
|
||||
|
||||
|
||||
private final ConnectionCallback cameraConnectionCallback;
|
||||
|
||||
private CameraConnectionFragment(
|
||||
final Classifier classifier, final int layout, final int inputSize) {
|
||||
this.classifier = classifier;
|
||||
final ConnectionCallback connectionCallback,
|
||||
final OnImageAvailableListener imageListener,
|
||||
final int layout, final int inputSize) {
|
||||
this.cameraConnectionCallback = connectionCallback;
|
||||
this.imageListener = imageListener;
|
||||
this.layout = layout;
|
||||
this.inputSize = inputSize;
|
||||
}
|
||||
@ -268,8 +270,12 @@ public class CameraConnectionFragment extends Fragment {
|
||||
final Size[] choices, final int width, final int height, final Size aspectRatio) {
|
||||
// Collect the supported resolutions that are at least as big as the preview Surface
|
||||
final List<Size> bigEnough = new ArrayList<Size>();
|
||||
|
||||
final int minWidth = Math.max(width, MINIMUM_PREVIEW_SIZE);
|
||||
final int minHeight = Math.max(height, MINIMUM_PREVIEW_SIZE);
|
||||
|
||||
for (final Size option : choices) {
|
||||
if (option.getHeight() >= MINIMUM_PREVIEW_SIZE && option.getWidth() >= MINIMUM_PREVIEW_SIZE) {
|
||||
if (option.getHeight() >= minHeight && option.getWidth() >= minWidth) {
|
||||
LOGGER.i("Adding size: " + option.getWidth() + "x" + option.getHeight());
|
||||
bigEnough.add(option);
|
||||
} else {
|
||||
@ -289,8 +295,9 @@ public class CameraConnectionFragment extends Fragment {
|
||||
}
|
||||
|
||||
public static CameraConnectionFragment newInstance(
|
||||
final Classifier classifier, final int layout, final int inputSize) {
|
||||
return new CameraConnectionFragment(classifier, layout, inputSize);
|
||||
final ConnectionCallback callback,
|
||||
final OnImageAvailableListener imageListener, final int layout, final int inputSize) {
|
||||
return new CameraConnectionFragment(callback, imageListener, layout, inputSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -302,7 +309,6 @@ public class CameraConnectionFragment extends Fragment {
|
||||
@Override
|
||||
public void onViewCreated(final View view, final Bundle savedInstanceState) {
|
||||
textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
|
||||
resultsView = (ResultsView) view.findViewById(R.id.results);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -371,7 +377,8 @@ public class CameraConnectionFragment extends Fragment {
|
||||
// bus' bandwidth limitation, resulting in gorgeous previews but the storage of
|
||||
// garbage capture data.
|
||||
previewSize =
|
||||
chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class), width, height, largest);
|
||||
chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class),
|
||||
inputSize, inputSize, largest);
|
||||
|
||||
// We fit the aspect ratio of TextureView to the size of preview we picked.
|
||||
final int orientation = getResources().getConfiguration().orientation;
|
||||
@ -382,6 +389,8 @@ public class CameraConnectionFragment extends Fragment {
|
||||
}
|
||||
|
||||
CameraConnectionFragment.this.cameraId = cameraId;
|
||||
|
||||
cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation);
|
||||
return;
|
||||
}
|
||||
} catch (final CameraAccessException e) {
|
||||
@ -446,10 +455,6 @@ public class CameraConnectionFragment extends Fragment {
|
||||
backgroundThread = new HandlerThread("ImageListener");
|
||||
backgroundThread.start();
|
||||
backgroundHandler = new Handler(backgroundThread.getLooper());
|
||||
|
||||
inferenceThread = new HandlerThread("InferenceThread");
|
||||
inferenceThread.start();
|
||||
inferenceHandler = new Handler(inferenceThread.getLooper());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -457,22 +462,15 @@ public class CameraConnectionFragment extends Fragment {
|
||||
*/
|
||||
private void stopBackgroundThread() {
|
||||
backgroundThread.quitSafely();
|
||||
inferenceThread.quitSafely();
|
||||
try {
|
||||
backgroundThread.join();
|
||||
backgroundThread = null;
|
||||
backgroundHandler = null;
|
||||
|
||||
inferenceThread.join();
|
||||
inferenceThread = null;
|
||||
inferenceThread = null;
|
||||
} catch (final InterruptedException e) {
|
||||
LOGGER.e(e, "Exception!");
|
||||
}
|
||||
}
|
||||
|
||||
private final TensorFlowImageListener tfPreviewListener = new TensorFlowImageListener();
|
||||
|
||||
private final CameraCaptureSession.CaptureCallback captureCallback =
|
||||
new CameraCaptureSession.CaptureCallback() {
|
||||
@Override
|
||||
@ -513,7 +511,7 @@ public class CameraConnectionFragment extends Fragment {
|
||||
ImageReader.newInstance(
|
||||
previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
|
||||
|
||||
previewReader.setOnImageAvailableListener(tfPreviewListener, backgroundHandler);
|
||||
previewReader.setOnImageAvailableListener(imageListener, backgroundHandler);
|
||||
previewRequestBuilder.addTarget(previewReader.getSurface());
|
||||
|
||||
// Here, we create a CameraCaptureSession for camera preview.
|
||||
@ -557,11 +555,6 @@ public class CameraConnectionFragment extends Fragment {
|
||||
} catch (final CameraAccessException e) {
|
||||
LOGGER.e(e, "Exception!");
|
||||
}
|
||||
|
||||
LOGGER.i("Getting assets.");
|
||||
tfPreviewListener.initialize(
|
||||
classifier, resultsView, inputSize, inferenceHandler, sensorOrientation);
|
||||
LOGGER.i("TensorFlow initialized.");
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -16,12 +16,29 @@
|
||||
|
||||
package org.tensorflow.demo;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.Bitmap.Config;
|
||||
import android.graphics.Canvas;
|
||||
import android.graphics.Matrix;
|
||||
import android.graphics.Paint;
|
||||
import android.media.Image;
|
||||
import android.media.Image.Plane;
|
||||
import android.media.ImageReader;
|
||||
import android.media.ImageReader.OnImageAvailableListener;
|
||||
import android.os.SystemClock;
|
||||
import android.os.Trace;
|
||||
import android.util.Size;
|
||||
import android.util.TypedValue;
|
||||
import android.view.Display;
|
||||
import java.io.IOException;
|
||||
|
||||
import android.app.Fragment;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
import org.tensorflow.demo.OverlayView.DrawCallback;
|
||||
import org.tensorflow.demo.env.BorderedText;
|
||||
import org.tensorflow.demo.env.ImageUtils;
|
||||
import org.tensorflow.demo.env.Logger;
|
||||
|
||||
public class ClassifierActivity extends CameraActivity {
|
||||
public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener {
|
||||
private static final Logger LOGGER = new Logger();
|
||||
|
||||
// These are the settings for the original v1 Inception model. If you want to
|
||||
@ -41,9 +58,58 @@ public class ClassifierActivity extends CameraActivity {
|
||||
private static final String LABEL_FILE =
|
||||
"file:///android_asset/imagenet_comp_graph_label_strings.txt";
|
||||
|
||||
private static final boolean SAVE_PREVIEW_BITMAP = false;
|
||||
|
||||
private static final boolean MAINTAIN_ASPECT = true;
|
||||
|
||||
private TensorFlowImageClassifier classifier;
|
||||
|
||||
private Integer sensorOrientation;
|
||||
|
||||
private int previewWidth = 0;
|
||||
private int previewHeight = 0;
|
||||
private byte[][] yuvBytes;
|
||||
private int[] rgbBytes = null;
|
||||
private Bitmap rgbFrameBitmap = null;
|
||||
private Bitmap croppedBitmap = null;
|
||||
|
||||
private Bitmap cropCopyBitmap;
|
||||
|
||||
private boolean computing = false;
|
||||
|
||||
private long timestamp = 0;
|
||||
|
||||
private Matrix frameToCropTransform;
|
||||
private Matrix cropToFrameTransform;
|
||||
|
||||
private ResultsView resultsView;
|
||||
|
||||
private OverlayView overlayView;
|
||||
|
||||
private BorderedText borderedText;
|
||||
|
||||
private long lastProcessingTimeMs;
|
||||
|
||||
@Override
|
||||
protected Fragment createFragment() {
|
||||
final TensorFlowImageClassifier classifier = new TensorFlowImageClassifier();
|
||||
protected int getLayoutId() {
|
||||
return R.layout.camera_connection_fragment;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int getDesiredPreviewFrameSize() {
|
||||
return INPUT_SIZE;
|
||||
}
|
||||
|
||||
private static final float TEXT_SIZE_DIP = 18;
|
||||
|
||||
@Override
|
||||
public void onPreviewSizeChosen(final Size size, final int rotation) {
|
||||
final float textSizePx = TypedValue.applyDimension(
|
||||
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
|
||||
getResources().getDisplayMetrics());
|
||||
borderedText = new BorderedText(textSizePx);
|
||||
|
||||
classifier = new TensorFlowImageClassifier();
|
||||
try {
|
||||
classifier.initializeTensorFlow(
|
||||
getAssets(), MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD,
|
||||
@ -52,7 +118,151 @@ public class ClassifierActivity extends CameraActivity {
|
||||
LOGGER.e(e, "Exception!");
|
||||
}
|
||||
|
||||
return CameraConnectionFragment.newInstance(
|
||||
classifier, R.layout.camera_connection_fragment, INPUT_SIZE);
|
||||
overlayView = (OverlayView) findViewById(R.id.overlay);
|
||||
resultsView = (ResultsView) findViewById(R.id.results);
|
||||
previewWidth = size.getWidth();
|
||||
previewHeight = size.getHeight();
|
||||
|
||||
final Display display = getWindowManager().getDefaultDisplay();
|
||||
final int screenOrientation = display.getRotation();
|
||||
|
||||
LOGGER.i("Sensor orientation: %d, Screen orientation: %d",
|
||||
rotation, screenOrientation);
|
||||
|
||||
sensorOrientation = rotation + screenOrientation;
|
||||
|
||||
if (sensorOrientation % 180 == 90) {
|
||||
overlayView.setAspectRatio(size.getHeight(), size.getWidth());
|
||||
} else {
|
||||
overlayView.setAspectRatio(size.getWidth(), size.getHeight());
|
||||
}
|
||||
|
||||
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
|
||||
rgbBytes = new int[previewWidth * previewHeight];
|
||||
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
|
||||
croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
|
||||
|
||||
frameToCropTransform = ImageUtils.getTransformationMatrix(
|
||||
previewWidth, previewHeight,
|
||||
INPUT_SIZE, INPUT_SIZE,
|
||||
sensorOrientation, MAINTAIN_ASPECT);
|
||||
|
||||
cropToFrameTransform = new Matrix();
|
||||
frameToCropTransform.invert(cropToFrameTransform);
|
||||
|
||||
yuvBytes = new byte[3][];
|
||||
|
||||
overlayView.addCallback(new DrawCallback() {
|
||||
@Override
|
||||
public void drawCallback(final Canvas canvas) {
|
||||
renderDebug(canvas);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onImageAvailable(final ImageReader reader) {
|
||||
Image image = null;
|
||||
|
||||
++timestamp;
|
||||
|
||||
try {
|
||||
image = reader.acquireLatestImage();
|
||||
|
||||
if (image == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (computing) {
|
||||
image.close();
|
||||
return;
|
||||
}
|
||||
computing = true;
|
||||
|
||||
Trace.beginSection("imageAvailable");
|
||||
|
||||
final Plane[] planes = image.getPlanes();
|
||||
fillBytes(planes, yuvBytes);
|
||||
|
||||
final int yRowStride = planes[0].getRowStride();
|
||||
final int uvRowStride = planes[1].getRowStride();
|
||||
final int uvPixelStride = planes[1].getPixelStride();
|
||||
ImageUtils.convertYUV420ToARGB8888(
|
||||
yuvBytes[0],
|
||||
yuvBytes[1],
|
||||
yuvBytes[2],
|
||||
rgbBytes,
|
||||
previewWidth,
|
||||
previewHeight,
|
||||
yRowStride,
|
||||
uvRowStride,
|
||||
uvPixelStride,
|
||||
false);
|
||||
|
||||
image.close();
|
||||
} catch (final Exception e) {
|
||||
if (image != null) {
|
||||
image.close();
|
||||
}
|
||||
LOGGER.e(e, "Exception!");
|
||||
Trace.endSection();
|
||||
return;
|
||||
}
|
||||
|
||||
rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
|
||||
final Canvas canvas = new Canvas(croppedBitmap);
|
||||
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
|
||||
|
||||
// For examining the actual TF input.
|
||||
if (SAVE_PREVIEW_BITMAP) {
|
||||
ImageUtils.saveBitmap(croppedBitmap);
|
||||
}
|
||||
|
||||
runInBackground(
|
||||
new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
final long startTime = SystemClock.uptimeMillis();
|
||||
final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
|
||||
lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
|
||||
|
||||
cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
|
||||
resultsView.setResults(results);
|
||||
overlayView.postInvalidate();
|
||||
computing = false;
|
||||
}
|
||||
});
|
||||
|
||||
Trace.endSection();
|
||||
}
|
||||
|
||||
private void renderDebug(Canvas canvas) {
|
||||
if (!isDebug()) {
|
||||
return;
|
||||
}
|
||||
final Bitmap copy = cropCopyBitmap;
|
||||
if (copy != null) {
|
||||
final Matrix matrix = new Matrix();
|
||||
final float scaleFactor = 2;
|
||||
matrix.postScale(scaleFactor, scaleFactor);
|
||||
matrix.postTranslate(
|
||||
canvas.getWidth() - copy.getWidth() * scaleFactor,
|
||||
canvas.getHeight() - copy.getHeight() * scaleFactor);
|
||||
canvas.drawBitmap(copy, matrix, new Paint());
|
||||
|
||||
final Vector<String> lines = new Vector<String>();
|
||||
lines.add("Frame: " + previewWidth + "x" + previewHeight);
|
||||
lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
|
||||
lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
|
||||
lines.add("Rotation: " + sensorOrientation);
|
||||
lines.add("Inference time: " + lastProcessingTimeMs + "ms");
|
||||
|
||||
int lineNum = 0;
|
||||
for (final String line : lines) {
|
||||
borderedText.drawText(canvas, 10,
|
||||
canvas.getHeight() - 10 - borderedText.getTextSize() * lineNum, line);
|
||||
++lineNum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,94 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.demo;
|
||||
|
||||
import android.content.Context;
|
||||
import android.graphics.Canvas;
|
||||
import android.util.AttributeSet;
|
||||
import android.view.MotionEvent;
|
||||
import android.view.View;
|
||||
import android.view.View.MeasureSpec;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A simple View providing a render callback to other classes.
|
||||
*/
|
||||
public class OverlayView extends View {
|
||||
public OverlayView(final Context context, final AttributeSet attrs) {
|
||||
super(context, attrs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface defining the callback for client classes.
|
||||
*/
|
||||
public interface DrawCallback {
|
||||
public void drawCallback(final Canvas canvas);
|
||||
}
|
||||
|
||||
private int ratioWidth;
|
||||
private int ratioHeight;
|
||||
|
||||
private boolean debug;
|
||||
|
||||
private final List<DrawCallback> callbacks = new LinkedList<DrawCallback>();
|
||||
|
||||
@Override
|
||||
public boolean onTouchEvent(final MotionEvent e) {
|
||||
super.onTouchEvent(e);
|
||||
if (e.getAction() == MotionEvent.ACTION_DOWN) {
|
||||
debug = !debug;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public void addCallback(final DrawCallback callback) {
|
||||
callbacks.add(callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void draw(final Canvas canvas) {
|
||||
for (final DrawCallback callback : callbacks) {
|
||||
callback.drawCallback(canvas);
|
||||
}
|
||||
}
|
||||
|
||||
public void setAspectRatio(final int width, final int height) {
|
||||
if (width < 0 || height < 0) {
|
||||
throw new IllegalArgumentException("Size cannot be negative.");
|
||||
}
|
||||
ratioWidth = width;
|
||||
ratioHeight = height;
|
||||
requestLayout();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
|
||||
super.onMeasure(widthMeasureSpec, heightMeasureSpec);
|
||||
final int width = MeasureSpec.getSize(widthMeasureSpec);
|
||||
final int height = MeasureSpec.getSize(heightMeasureSpec);
|
||||
if (0 == ratioWidth || 0 == ratioHeight) {
|
||||
setMeasuredDimension(width, height);
|
||||
} else {
|
||||
if (width < height * ratioWidth / ratioHeight) {
|
||||
setMeasuredDimension(width, width * ratioHeight / ratioWidth);
|
||||
} else {
|
||||
setMeasuredDimension(height * ratioWidth / ratioHeight, height);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
119
tensorflow/examples/android/src/org/tensorflow/demo/env/BorderedText.java
vendored
Normal file
119
tensorflow/examples/android/src/org/tensorflow/demo/env/BorderedText.java
vendored
Normal file
@ -0,0 +1,119 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.demo.env;
|
||||
|
||||
import android.graphics.Canvas;
|
||||
import android.graphics.Color;
|
||||
import android.graphics.Paint;
|
||||
import android.graphics.Paint.Align;
|
||||
import android.graphics.Paint.Style;
|
||||
import android.graphics.Rect;
|
||||
|
||||
/**
|
||||
* A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas.
|
||||
*/
|
||||
public class BorderedText {
|
||||
private final Paint interiorPaint;
|
||||
private final Paint exteriorPaint;
|
||||
|
||||
private final float textSize;
|
||||
|
||||
/**
|
||||
* Creates a left-aligned bordered text object with a white interior, and a black exterior with
|
||||
* the specified text size.
|
||||
*
|
||||
* @param textSize text size in pixels
|
||||
*/
|
||||
public BorderedText(final float textSize) {
|
||||
this(Color.WHITE, Color.BLACK, textSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a bordered text object with the specified interior and exterior colors, text size and
|
||||
* alignment.
|
||||
*
|
||||
* @param interiorColor the interior text color
|
||||
* @param exteriorColor the exterior text color
|
||||
* @param textSize text size in pixels
|
||||
*/
|
||||
public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) {
|
||||
interiorPaint = new Paint();
|
||||
interiorPaint.setTextSize(textSize);
|
||||
interiorPaint.setColor(interiorColor);
|
||||
interiorPaint.setStyle(Style.FILL);
|
||||
interiorPaint.setAntiAlias(false);
|
||||
interiorPaint.setAlpha(255);
|
||||
|
||||
exteriorPaint = new Paint();
|
||||
exteriorPaint.setTextSize(textSize);
|
||||
exteriorPaint.setColor(exteriorColor);
|
||||
exteriorPaint.setStyle(Style.FILL_AND_STROKE);
|
||||
exteriorPaint.setStrokeWidth(textSize / 8);
|
||||
exteriorPaint.setAntiAlias(false);
|
||||
exteriorPaint.setAlpha(255);
|
||||
|
||||
this.textSize = textSize;
|
||||
}
|
||||
|
||||
public void drawText(final Canvas canvas, final float posX, final float posY, final String text) {
|
||||
/*
|
||||
if (widths == null || widths.length < text.length()) {
|
||||
widths = new float[text.length()];
|
||||
positions = new float[text.length() * 2];
|
||||
}
|
||||
|
||||
exteriorPaint.getTextWidths(text, widths);
|
||||
float lastPosX = posX;
|
||||
for (int i = 0; i < widths.length; ++i) {
|
||||
positions[i * 2] = lastPosX;
|
||||
positions[i * 2 + 1] = posY;
|
||||
lastPosX += widths[i];
|
||||
}
|
||||
*/
|
||||
|
||||
//canvas.drawPosText(text, positions, exteriorPaint);
|
||||
//canvas.drawPosText(text, positions, exteriorPaint);
|
||||
canvas.drawText(text, posX, posY, exteriorPaint);
|
||||
canvas.drawText(text, posX, posY, interiorPaint);
|
||||
}
|
||||
|
||||
public void setInteriorColor(final int color) {
|
||||
interiorPaint.setColor(color);
|
||||
}
|
||||
|
||||
public void setExteriorColor(final int color) {
|
||||
exteriorPaint.setColor(color);
|
||||
}
|
||||
|
||||
public float getTextSize() {
|
||||
return textSize;
|
||||
}
|
||||
|
||||
public void setAlpha(final int alpha) {
|
||||
interiorPaint.setAlpha(alpha);
|
||||
exteriorPaint.setAlpha(alpha);
|
||||
}
|
||||
|
||||
public void getTextBounds(
|
||||
final String line, final int index, final int count, final Rect lineBounds) {
|
||||
interiorPaint.getTextBounds(line, index, count, lineBounds);
|
||||
}
|
||||
|
||||
public void setTextAlign(final Align align) {
|
||||
interiorPaint.setTextAlign(align);
|
||||
exteriorPaint.setTextAlign(align);
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user