Merge commit for internal changes
This commit is contained in:
commit
b96158c08d
@ -62,8 +62,6 @@ cc_library(
|
|||||||
# This define (mostly) guarantees we don't link any problematic
|
# This define (mostly) guarantees we don't link any problematic
|
||||||
# code. We use it, but we do not rely on it, as evidenced above.
|
# code. We use it, but we do not rely on it, as evidenced above.
|
||||||
"EIGEN_MPL2_ONLY",
|
"EIGEN_MPL2_ONLY",
|
||||||
# TODO(jart): Use EIGEN_USE_NONBLOCKING_THREAD_POOL but first add an
|
|
||||||
# eigen_initialize.cc file and alwayslink=1.
|
|
||||||
],
|
],
|
||||||
includes = ["."],
|
includes = ["."],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
|
@ -105,6 +105,7 @@ filegroup(
|
|||||||
"//tensorflow/contrib/framework:all_files",
|
"//tensorflow/contrib/framework:all_files",
|
||||||
"//tensorflow/contrib/graph_editor:all_files",
|
"//tensorflow/contrib/graph_editor:all_files",
|
||||||
"//tensorflow/contrib/grid_rnn:all_files",
|
"//tensorflow/contrib/grid_rnn:all_files",
|
||||||
|
"//tensorflow/contrib/integrate:all_files",
|
||||||
"//tensorflow/contrib/layers:all_files",
|
"//tensorflow/contrib/layers:all_files",
|
||||||
"//tensorflow/contrib/layers/kernels:all_files",
|
"//tensorflow/contrib/layers/kernels:all_files",
|
||||||
"//tensorflow/contrib/learn:all_files",
|
"//tensorflow/contrib/learn:all_files",
|
||||||
@ -148,7 +149,6 @@ filegroup(
|
|||||||
"//tensorflow/examples/image_retraining:all_files",
|
"//tensorflow/examples/image_retraining:all_files",
|
||||||
"//tensorflow/examples/label_image:all_files",
|
"//tensorflow/examples/label_image:all_files",
|
||||||
"//tensorflow/examples/learn:all_files",
|
"//tensorflow/examples/learn:all_files",
|
||||||
"//tensorflow/examples/skflow:all_files",
|
|
||||||
"//tensorflow/examples/tutorials/estimators:all_files",
|
"//tensorflow/examples/tutorials/estimators:all_files",
|
||||||
"//tensorflow/examples/tutorials/mnist:all_files",
|
"//tensorflow/examples/tutorials/mnist:all_files",
|
||||||
"//tensorflow/examples/tutorials/word2vec:all_files",
|
"//tensorflow/examples/tutorials/word2vec:all_files",
|
||||||
|
@ -264,6 +264,36 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "nn_grad",
|
||||||
|
srcs = ["gradients/nn_grad.cc"],
|
||||||
|
deps = [
|
||||||
|
":cc_ops",
|
||||||
|
":grad_op_registry",
|
||||||
|
":ops",
|
||||||
|
":scope",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "gradients_nn_grad_test",
|
||||||
|
srcs = ["gradients/nn_grad_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":cc_ops",
|
||||||
|
":grad_op_registry",
|
||||||
|
":grad_testutil",
|
||||||
|
":gradient_checker",
|
||||||
|
":nn_grad",
|
||||||
|
":testutil",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_gen_op_wrappers_cc(
|
tf_gen_op_wrappers_cc(
|
||||||
name = "cc_ops",
|
name = "cc_ops",
|
||||||
op_lib_names = [
|
op_lib_names = [
|
||||||
|
@ -110,20 +110,15 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const ops::Output& x,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
Status ComputeGradientErrorInternal(const Scope& scope, const ops::Output& x,
|
||||||
const TensorShape& x_shape, const ops::Output& y,
|
const TensorShape& x_shape,
|
||||||
const TensorShape& y_shape, T* max_error) {
|
const ops::Output& y,
|
||||||
|
const TensorShape& y_shape, Tensor* x_data,
|
||||||
|
T* max_error) {
|
||||||
const int64 x_size = x_shape.num_elements();
|
const int64 x_size = x_shape.num_elements();
|
||||||
const int64 y_size = y_shape.num_elements();
|
const int64 y_size = y_shape.num_elements();
|
||||||
|
|
||||||
// Initialize 'x_data' to random values.
|
|
||||||
Tensor x_data(x.type(), x_shape);
|
|
||||||
auto x_data_flat = x_data.flat<T>();
|
|
||||||
x_data_flat.setRandom();
|
|
||||||
|
|
||||||
// Initialize theoretical Jacobian to zeros.
|
// Initialize theoretical Jacobian to zeros.
|
||||||
Tensor jacobian_t(x.type(), {x_size, y_size});
|
Tensor jacobian_t(x.type(), {x_size, y_size});
|
||||||
auto jacobian_t_flat = jacobian_t.flat<T>();
|
auto jacobian_t_flat = jacobian_t.flat<T>();
|
||||||
@ -131,7 +126,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
|||||||
|
|
||||||
// Compute theoretical Jacobian.
|
// Compute theoretical Jacobian.
|
||||||
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
|
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
|
||||||
scope, x, x_shape, x_data, y, y_shape, &jacobian_t));
|
scope, x, x_shape, *x_data, y, y_shape, &jacobian_t));
|
||||||
|
|
||||||
// Initialize numeric Jacobian to zeros.
|
// Initialize numeric Jacobian to zeros.
|
||||||
Tensor jacobian_n(x.type(), {x_size, y_size});
|
Tensor jacobian_n(x.type(), {x_size, y_size});
|
||||||
@ -140,7 +135,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
|||||||
|
|
||||||
// Compute numeric Jacobian.
|
// Compute numeric Jacobian.
|
||||||
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
|
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
|
||||||
scope, x, x_shape, y, y_shape, 1e-3, &x_data, &jacobian_n));
|
scope, x, x_shape, y, y_shape, 1e-3, x_data, &jacobian_n));
|
||||||
|
|
||||||
// Compute the maximum error between theoretical and numeric Jacobians.
|
// Compute the maximum error between theoretical and numeric Jacobians.
|
||||||
*max_error = 0.0;
|
*max_error = 0.0;
|
||||||
@ -154,10 +149,39 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||||
|
const TensorShape& x_shape, const ops::Output& y,
|
||||||
|
const TensorShape& y_shape, T* max_error) {
|
||||||
|
// Initialize 'x_data' to random values.
|
||||||
|
Tensor x_data(x.type(), x_shape);
|
||||||
|
auto x_data_flat = x_data.flat<T>();
|
||||||
|
x_data_flat.setRandom();
|
||||||
|
// Compute gradient error.
|
||||||
|
return ComputeGradientErrorInternal(scope, x, x_shape, y, y_shape, &x_data,
|
||||||
|
max_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||||
|
const Tensor& x_init_value, const ops::Output& y,
|
||||||
|
const TensorShape& y_shape, T* max_error) {
|
||||||
|
// Initialize 'x_data' from 'x_init_value'.
|
||||||
|
Tensor x_data(x_init_value);
|
||||||
|
// Compute gradient error.
|
||||||
|
return ComputeGradientErrorInternal(scope, x, x_data.shape(), y, y_shape,
|
||||||
|
&x_data, max_error);
|
||||||
|
}
|
||||||
|
|
||||||
#define INSTANTIATE_GRAD_ERR_TYPE(T) \
|
#define INSTANTIATE_GRAD_ERR_TYPE(T) \
|
||||||
template Status ComputeGradientError<T>( \
|
template Status ComputeGradientError<T>( \
|
||||||
const Scope& scope, const ops::Output& x, const TensorShape& x_shape, \
|
const Scope& scope, const ops::Output& x, const TensorShape& x_shape, \
|
||||||
const ops::Output& y, const TensorShape& y_shape, T* max_error)
|
const ops::Output& y, const TensorShape& y_shape, T* max_error); \
|
||||||
|
template Status ComputeGradientError<T>( \
|
||||||
|
const Scope& scope, const ops::Output& x, const Tensor& x_init_value, \
|
||||||
|
const ops::Output& y, const TensorShape& y_shape, T* max_error);
|
||||||
|
|
||||||
INSTANTIATE_GRAD_ERR_TYPE(float);
|
INSTANTIATE_GRAD_ERR_TYPE(float);
|
||||||
INSTANTIATE_GRAD_ERR_TYPE(double);
|
INSTANTIATE_GRAD_ERR_TYPE(double);
|
||||||
|
@ -30,6 +30,12 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
|||||||
const TensorShape& x_shape, const ops::Output& y,
|
const TensorShape& x_shape, const ops::Output& y,
|
||||||
const TensorShape& y_shape, T* max_error);
|
const TensorShape& y_shape, T* max_error);
|
||||||
|
|
||||||
|
// Overload of ComputeGradientError which takes an initial value for 'x'.
|
||||||
|
template <typename T>
|
||||||
|
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||||
|
const Tensor& x_init_value, const ops::Output& y,
|
||||||
|
const TensorShape& y_shape, T* max_error);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
|
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
|
||||||
|
77
tensorflow/cc/gradients/nn_grad.cc
Normal file
77
tensorflow/cc/gradients/nn_grad.cc
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace ops {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
Status SoftmaxGrad(const Scope& scope, const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
// Softmax gradient function.
|
||||||
|
// p = softmax(x) maps from [batch, n] to [batch, m]
|
||||||
|
// dp/dx = [dp0/dx0 ... dp0/dxn-1 ]
|
||||||
|
// [ ... ... ]
|
||||||
|
// [dpm-1/dx0 ... dpm-1/dxn-1]
|
||||||
|
// dL/dx = dp/dx * dL/dy
|
||||||
|
//
|
||||||
|
// Using alternative formula:
|
||||||
|
// dL/dx = dL/dy * y - sum(dL/dy * y) * y
|
||||||
|
// = (dL/dy - sum(dL/dy * y)) * y
|
||||||
|
auto y = op.output(0);
|
||||||
|
auto dyy = Mul(scope, grad_inputs[0], y);
|
||||||
|
auto sum = Reshape(scope, Sum(scope, dyy, {1}), {-1, 1});
|
||||||
|
auto sub = Sub(scope, grad_inputs[0], sum);
|
||||||
|
auto dx = Mul(scope, sub, y);
|
||||||
|
grad_outputs->push_back(dx);
|
||||||
|
return scope.status();
|
||||||
|
}
|
||||||
|
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
|
||||||
|
|
||||||
|
Status ReluGradHelper(const Scope& scope, const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
auto dx = ReluGrad(scope, grad_inputs[0], op.input(0));
|
||||||
|
grad_outputs->push_back(dx);
|
||||||
|
return scope.status();
|
||||||
|
}
|
||||||
|
REGISTER_GRADIENT_OP("Relu", ReluGradHelper);
|
||||||
|
|
||||||
|
Status Relu6GradHelper(const Scope& scope, const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
auto dx = Relu6Grad(scope, grad_inputs[0], op.input(0));
|
||||||
|
grad_outputs->push_back(dx);
|
||||||
|
return scope.status();
|
||||||
|
}
|
||||||
|
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);
|
||||||
|
|
||||||
|
Status EluGradHelper(const Scope& scope, const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
auto dx = EluGrad(scope, grad_inputs[0], op.output(0));
|
||||||
|
grad_outputs->push_back(dx);
|
||||||
|
return scope.status();
|
||||||
|
}
|
||||||
|
REGISTER_GRADIENT_OP("Elu", EluGradHelper);
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace tensorflow
|
91
tensorflow/cc/gradients/nn_grad_test.cc
Normal file
91
tensorflow/cc/gradients/nn_grad_test.cc
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||||
|
#include "tensorflow/cc/framework/gradient_checker.h"
|
||||||
|
#include "tensorflow/cc/framework/testutil.h"
|
||||||
|
#include "tensorflow/cc/gradients/grad_testutil.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
using namespace ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class NNGradTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
NNGradTest() : scope_(Scope::NewRootScope()) {}
|
||||||
|
|
||||||
|
void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
|
||||||
|
const TensorShape& y_shape) {
|
||||||
|
float max_error;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
ComputeGradientError(scope_, x, x_shape, y, y_shape, &max_error));
|
||||||
|
EXPECT_LT(max_error, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
|
||||||
|
const TensorShape& y_shape) {
|
||||||
|
float max_error;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
|
||||||
|
EXPECT_LT(max_error, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
Scope scope_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(NNGradTest, SoftmaxGrad) {
|
||||||
|
TensorShape shape({32, 10});
|
||||||
|
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||||
|
auto y = Softmax(scope_, x);
|
||||||
|
RunTest(x, shape, y, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NNGradTest, ReluGrad) {
|
||||||
|
TensorShape shape({5, 2});
|
||||||
|
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||||
|
auto y = Relu(scope_, x);
|
||||||
|
// Avoid input values where ReLU gradient is not well defined (around zero).
|
||||||
|
Tensor x_init_value = test::AsTensor<float>(
|
||||||
|
{-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, {5, 2});
|
||||||
|
RunTest(x, x_init_value, y, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NNGradTest, Relu6Grad) {
|
||||||
|
TensorShape shape({5, 2});
|
||||||
|
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||||
|
auto y = Relu6(scope_, x);
|
||||||
|
// Avoid input values where ReLU gradient is not well defined (around zero
|
||||||
|
// and six).
|
||||||
|
Tensor x_init_value = test::AsTensor<float>(
|
||||||
|
{-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9}, {5, 2});
|
||||||
|
RunTest(x, x_init_value, y, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NNGradTest, EluGrad) {
|
||||||
|
TensorShape shape({5, 2});
|
||||||
|
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||||
|
auto y = Elu(scope_, x);
|
||||||
|
Tensor x_init_value = test::AsTensor<float>(
|
||||||
|
{-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, {5, 2});
|
||||||
|
RunTest(x, x_init_value, y, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -23,6 +23,7 @@ py_library(
|
|||||||
"//tensorflow/contrib/framework:framework_py",
|
"//tensorflow/contrib/framework:framework_py",
|
||||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||||
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
|
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
|
||||||
|
"//tensorflow/contrib/integrate:integrate_py",
|
||||||
"//tensorflow/contrib/layers:layers_py",
|
"//tensorflow/contrib/layers:layers_py",
|
||||||
"//tensorflow/contrib/learn",
|
"//tensorflow/contrib/learn",
|
||||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib import factorization
|
|||||||
from tensorflow.contrib import framework
|
from tensorflow.contrib import framework
|
||||||
from tensorflow.contrib import graph_editor
|
from tensorflow.contrib import graph_editor
|
||||||
from tensorflow.contrib import grid_rnn
|
from tensorflow.contrib import grid_rnn
|
||||||
|
from tensorflow.contrib import integrate
|
||||||
from tensorflow.contrib import layers
|
from tensorflow.contrib import layers
|
||||||
from tensorflow.contrib import learn
|
from tensorflow.contrib import learn
|
||||||
from tensorflow.contrib import linear_optimizer
|
from tensorflow.contrib import linear_optimizer
|
||||||
|
@ -76,7 +76,7 @@ def build_split_apply_merge_model():
|
|||||||
|
|
||||||
# REINFORCE forward step
|
# REINFORCE forward step
|
||||||
route_selection = st.StochasticTensor(
|
route_selection = st.StochasticTensor(
|
||||||
distributions.Categorical, logits=logits)
|
distributions.Categorical(logits=logits))
|
||||||
|
|
||||||
# Accessing route_selection as a Tensor below forces a sample of
|
# Accessing route_selection as a Tensor below forces a sample of
|
||||||
# the Categorical distribution based on its logits.
|
# the Categorical distribution based on its logits.
|
||||||
|
@ -22,6 +22,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
st = tf.contrib.bayesflow.stochastic_tensor
|
st = tf.contrib.bayesflow.stochastic_tensor
|
||||||
sge = tf.contrib.bayesflow.stochastic_gradient_estimators
|
sge = tf.contrib.bayesflow.stochastic_gradient_estimators
|
||||||
|
dists = tf.contrib.distributions
|
||||||
|
|
||||||
|
|
||||||
class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||||
@ -31,7 +32,7 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
|||||||
self._final_loss = tf.constant(3.2)
|
self._final_loss = tf.constant(3.2)
|
||||||
|
|
||||||
def _testScoreFunction(self, loss_fn, expected):
|
def _testScoreFunction(self, loss_fn, expected):
|
||||||
x = st.BernoulliTensor(p=self._p, loss_fn=loss_fn)
|
x = st.StochasticTensor(dists.Bernoulli(p=self._p), loss_fn=loss_fn)
|
||||||
sf = x.loss(self._final_loss)
|
sf = x.loss(self._final_loss)
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.initialize_all_variables())
|
||||||
@ -62,8 +63,8 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
|||||||
def testScoreFunctionWithMeanBaseline(self):
|
def testScoreFunctionWithMeanBaseline(self):
|
||||||
ema_decay = 0.8
|
ema_decay = 0.8
|
||||||
num_steps = 6
|
num_steps = 6
|
||||||
x = st.BernoulliTensor(
|
x = st.StochasticTensor(
|
||||||
p=self._p,
|
dists.Bernoulli(p=self._p),
|
||||||
loss_fn=sge.get_score_function_with_baseline(
|
loss_fn=sge.get_score_function_with_baseline(
|
||||||
sge.get_mean_baseline(ema_decay)))
|
sge.get_mean_baseline(ema_decay)))
|
||||||
sf = x.loss(self._final_loss)
|
sf = x.loss(self._final_loss)
|
||||||
@ -98,12 +99,12 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
|
def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
|
||||||
ema_decay = 0.8
|
ema_decay = 0.8
|
||||||
x = st.BernoulliTensor(
|
x = st.StochasticTensor(
|
||||||
p=self._p,
|
dists.Bernoulli(p=self._p),
|
||||||
loss_fn=sge.get_score_function_with_baseline(
|
loss_fn=sge.get_score_function_with_baseline(
|
||||||
sge.get_mean_baseline(ema_decay)))
|
sge.get_mean_baseline(ema_decay)))
|
||||||
y = st.BernoulliTensor(
|
y = st.StochasticTensor(
|
||||||
p=self._p,
|
dists.Bernoulli(p=self._p),
|
||||||
loss_fn=sge.get_score_function_with_baseline(
|
loss_fn=sge.get_score_function_with_baseline(
|
||||||
sge.get_mean_baseline(ema_decay)))
|
sge.get_mean_baseline(ema_decay)))
|
||||||
sf_x = x.loss(self._final_loss)
|
sf_x = x.loss(self._final_loss)
|
||||||
|
@ -39,9 +39,9 @@ class TestSurrogateLosses(tf.test.TestCase):
|
|||||||
mu = [0.0, 0.1, 0.2]
|
mu = [0.0, 0.1, 0.2]
|
||||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||||
with st.value_type(st.SampleAndReshapeValue()):
|
with st.value_type(st.SampleAndReshapeValue()):
|
||||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||||
likelihood = st.StochasticTensor(
|
likelihood = st.StochasticTensor(
|
||||||
distributions.Normal, mu=prior, sigma=sigma)
|
distributions.Normal(mu=prior, sigma=sigma))
|
||||||
self.assertTrue(prior.distribution.is_reparameterized)
|
self.assertTrue(prior.distribution.is_reparameterized)
|
||||||
self.assertTrue(likelihood.distribution.is_reparameterized)
|
self.assertTrue(likelihood.distribution.is_reparameterized)
|
||||||
|
|
||||||
@ -77,10 +77,9 @@ class TestSurrogateLosses(tf.test.TestCase):
|
|||||||
mu = tf.constant([0.0, 0.1, 0.2])
|
mu = tf.constant([0.0, 0.1, 0.2])
|
||||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||||
with st.value_type(st.SampleAndReshapeValue()):
|
with st.value_type(st.SampleAndReshapeValue()):
|
||||||
prior = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||||
likelihood = st.StochasticTensor(
|
likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma))
|
||||||
NormalNotParam, mu=prior, sigma=sigma)
|
prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||||
prior_2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
|
||||||
|
|
||||||
loss = tf.square(tf.identity(likelihood) - mu)
|
loss = tf.square(tf.identity(likelihood) - mu)
|
||||||
part_loss = tf.square(tf.identity(prior) - mu)
|
part_loss = tf.square(tf.identity(prior) - mu)
|
||||||
@ -155,9 +154,7 @@ class TestSurrogateLosses(tf.test.TestCase):
|
|||||||
mu = tf.constant([0.0, 0.1, 0.2])
|
mu = tf.constant([0.0, 0.1, 0.2])
|
||||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||||
with st.value_type(st.SampleAndReshapeValue()):
|
with st.value_type(st.SampleAndReshapeValue()):
|
||||||
dt = st.StochasticTensor(NormalNotParam,
|
dt = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma),
|
||||||
mu=mu,
|
|
||||||
sigma=sigma,
|
|
||||||
loss_fn=None)
|
loss_fn=None)
|
||||||
self.assertEqual(None, dt.loss(tf.constant([2.0])))
|
self.assertEqual(None, dt.loss(tf.constant([2.0])))
|
||||||
|
|
||||||
@ -166,8 +163,8 @@ class TestSurrogateLosses(tf.test.TestCase):
|
|||||||
mu = tf.constant([0.0, 0.1, 0.2])
|
mu = tf.constant([0.0, 0.1, 0.2])
|
||||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||||
with st.value_type(st.SampleAndReshapeValue()):
|
with st.value_type(st.SampleAndReshapeValue()):
|
||||||
dt1 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||||
dt2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||||
loss = tf.square(tf.identity(dt1)) + 10. + dt2
|
loss = tf.square(tf.identity(dt1)) + 10. + dt2
|
||||||
|
|
||||||
sl_all = sg.surrogate_loss([loss])
|
sl_all = sg.surrogate_loss([loss])
|
||||||
@ -186,8 +183,8 @@ class TestSurrogateLosses(tf.test.TestCase):
|
|||||||
class StochasticDependenciesMapTest(tf.test.TestCase):
|
class StochasticDependenciesMapTest(tf.test.TestCase):
|
||||||
|
|
||||||
def testBuildsMapOfUpstreamNodes(self):
|
def testBuildsMapOfUpstreamNodes(self):
|
||||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||||
dt2 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
dt2 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||||
out1 = dt1.value() + 1.
|
out1 = dt1.value() + 1.
|
||||||
out2 = dt2.value() + 2.
|
out2 = dt2.value() + 2.
|
||||||
x = out1 + out2
|
x = out1 + out2
|
||||||
@ -197,11 +194,11 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
|
|||||||
self.assertEqual(dep_map[dt2], set([x, y]))
|
self.assertEqual(dep_map[dt2], set([x, y]))
|
||||||
|
|
||||||
def testHandlesStackedStochasticNodes(self):
|
def testHandlesStackedStochasticNodes(self):
|
||||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||||
out1 = dt1.value() + 1.
|
out1 = dt1.value() + 1.
|
||||||
dt2 = st.StochasticTensor(distributions.Normal, mu=out1, sigma=1.)
|
dt2 = st.StochasticTensor(distributions.Normal(mu=out1, sigma=1.))
|
||||||
x = dt2.value() + 2.
|
x = dt2.value() + 2.
|
||||||
dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||||
y = dt3.value() * 3.
|
y = dt3.value() * 3.
|
||||||
dep_map = sg._stochastic_dependencies_map([x, y])
|
dep_map = sg._stochastic_dependencies_map([x, y])
|
||||||
self.assertEqual(dep_map[dt1], set([x]))
|
self.assertEqual(dep_map[dt1], set([x]))
|
||||||
@ -209,10 +206,10 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
|
|||||||
self.assertEqual(dep_map[dt3], set([y]))
|
self.assertEqual(dep_map[dt3], set([y]))
|
||||||
|
|
||||||
def testTraversesControlInputs(self):
|
def testTraversesControlInputs(self):
|
||||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||||
logits = dt1.value() * 3.
|
logits = dt1.value() * 3.
|
||||||
dt2 = st.StochasticTensor(distributions.Bernoulli, logits=logits)
|
dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits))
|
||||||
dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||||
x = dt3.value()
|
x = dt3.value()
|
||||||
y = tf.ones((2, 2)) * 4.
|
y = tf.ones((2, 2)) * 4.
|
||||||
z = tf.ones((2, 2)) * 3.
|
z = tf.ones((2, 2)) * 3.
|
||||||
|
@ -35,19 +35,19 @@ class StochasticTensorTest(tf.test.TestCase):
|
|||||||
sigma2 = tf.constant([0.1, 0.2, 0.3])
|
sigma2 = tf.constant([0.1, 0.2, 0.3])
|
||||||
|
|
||||||
prior_default = st.StochasticTensor(
|
prior_default = st.StochasticTensor(
|
||||||
distributions.Normal, mu=mu, sigma=sigma)
|
distributions.Normal(mu=mu, sigma=sigma))
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
isinstance(prior_default.value_type, st.SampleAndReshapeValue))
|
isinstance(prior_default.value_type, st.SampleAndReshapeValue))
|
||||||
prior_0 = st.StochasticTensor(
|
prior_0 = st.StochasticTensor(
|
||||||
distributions.Normal, mu=mu, sigma=sigma,
|
distributions.Normal(mu=mu, sigma=sigma),
|
||||||
dist_value_type=st.SampleAndReshapeValue())
|
dist_value_type=st.SampleAndReshapeValue())
|
||||||
self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue))
|
self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue))
|
||||||
|
|
||||||
with st.value_type(st.SampleAndReshapeValue()):
|
with st.value_type(st.SampleAndReshapeValue()):
|
||||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||||
self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue))
|
self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue))
|
||||||
likelihood = st.StochasticTensor(
|
likelihood = st.StochasticTensor(
|
||||||
distributions.Normal, mu=prior, sigma=sigma2)
|
distributions.Normal(mu=prior, sigma=sigma2))
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
isinstance(likelihood.value_type, st.SampleAndReshapeValue))
|
isinstance(likelihood.value_type, st.SampleAndReshapeValue))
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
|||||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||||
|
|
||||||
with st.value_type(st.MeanValue()):
|
with st.value_type(st.MeanValue()):
|
||||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||||
self.assertTrue(isinstance(prior.value_type, st.MeanValue))
|
self.assertTrue(isinstance(prior.value_type, st.MeanValue))
|
||||||
|
|
||||||
prior_mean = prior.mean()
|
prior_mean = prior.mean()
|
||||||
@ -94,7 +94,8 @@ class StochasticTensorTest(tf.test.TestCase):
|
|||||||
|
|
||||||
with st.value_type(st.SampleAndReshapeValue()):
|
with st.value_type(st.SampleAndReshapeValue()):
|
||||||
prior_single = st.StochasticTensor(
|
prior_single = st.StochasticTensor(
|
||||||
distributions.Normal, mu=mu, sigma=sigma)
|
distributions.Normal(
|
||||||
|
mu=mu, sigma=sigma))
|
||||||
|
|
||||||
prior_single_value = prior_single.value()
|
prior_single_value = prior_single.value()
|
||||||
self.assertEqual(prior_single_value.get_shape(), (2, 3))
|
self.assertEqual(prior_single_value.get_shape(), (2, 3))
|
||||||
@ -104,7 +105,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
|||||||
|
|
||||||
with st.value_type(st.SampleAndReshapeValue(n=2)):
|
with st.value_type(st.SampleAndReshapeValue(n=2)):
|
||||||
prior_double = st.StochasticTensor(
|
prior_double = st.StochasticTensor(
|
||||||
distributions.Normal, mu=mu, sigma=sigma)
|
distributions.Normal(mu=mu, sigma=sigma))
|
||||||
|
|
||||||
prior_double_value = prior_double.value()
|
prior_double_value = prior_double.value()
|
||||||
self.assertEqual(prior_double_value.get_shape(), (4, 3))
|
self.assertEqual(prior_double_value.get_shape(), (4, 3))
|
||||||
@ -119,7 +120,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
|||||||
|
|
||||||
with st.value_type(st.SampleValue()):
|
with st.value_type(st.SampleValue()):
|
||||||
prior_single = st.StochasticTensor(
|
prior_single = st.StochasticTensor(
|
||||||
distributions.Normal, mu=mu, sigma=sigma)
|
distributions.Normal(mu=mu, sigma=sigma))
|
||||||
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
|
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
|
||||||
|
|
||||||
prior_single_value = prior_single.value()
|
prior_single_value = prior_single.value()
|
||||||
@ -130,7 +131,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
|||||||
|
|
||||||
with st.value_type(st.SampleValue(n=2)):
|
with st.value_type(st.SampleValue(n=2)):
|
||||||
prior_double = st.StochasticTensor(
|
prior_double = st.StochasticTensor(
|
||||||
distributions.Normal, mu=mu, sigma=sigma)
|
distributions.Normal(mu=mu, sigma=sigma))
|
||||||
|
|
||||||
prior_double_value = prior_double.value()
|
prior_double_value = prior_double.value()
|
||||||
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
|
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
|
||||||
@ -143,9 +144,9 @@ class StochasticTensorTest(tf.test.TestCase):
|
|||||||
mu = [0.0, -1.0, 1.0]
|
mu = [0.0, -1.0, 1.0]
|
||||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||||
with st.value_type(st.MeanValue()):
|
with st.value_type(st.MeanValue()):
|
||||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||||
entropy = prior.entropy()
|
entropy = prior.entropy()
|
||||||
deep_entropy = prior.entropy()
|
deep_entropy = prior.distribution.entropy()
|
||||||
expected_deep_entropy = distributions.Normal(
|
expected_deep_entropy = distributions.Normal(
|
||||||
mu=mu, sigma=sigma).entropy()
|
mu=mu, sigma=sigma).entropy()
|
||||||
entropies = sess.run([entropy, deep_entropy, expected_deep_entropy])
|
entropies = sess.run([entropy, deep_entropy, expected_deep_entropy])
|
||||||
@ -159,17 +160,15 @@ class StochasticTensorTest(tf.test.TestCase):
|
|||||||
|
|
||||||
# With default
|
# With default
|
||||||
with st.value_type(st.MeanValue(stop_gradient=True)):
|
with st.value_type(st.MeanValue(stop_gradient=True)):
|
||||||
dt = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
dt = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||||
loss = dt.loss([tf.constant(2.0)])
|
loss = dt.loss([tf.constant(2.0)])
|
||||||
self.assertTrue(loss is not None)
|
self.assertTrue(loss is not None)
|
||||||
self.assertAllClose(dt.distribution.log_prob(mu).eval() * 2.0,
|
self.assertAllClose(
|
||||||
loss.eval())
|
dt.distribution.log_prob(mu).eval() * 2.0, loss.eval())
|
||||||
|
|
||||||
# With passed-in loss_fn.
|
# With passed-in loss_fn.
|
||||||
dt = st.StochasticTensor(
|
dt = st.StochasticTensor(
|
||||||
distributions.Normal,
|
distributions.Normal(mu=mu, sigma=sigma),
|
||||||
mu=mu,
|
|
||||||
sigma=sigma,
|
|
||||||
dist_value_type=st.MeanValue(stop_gradient=True),
|
dist_value_type=st.MeanValue(stop_gradient=True),
|
||||||
loss_fn=sge.get_score_function_with_constant_baseline(
|
loss_fn=sge.get_score_function_with_constant_baseline(
|
||||||
baseline=tf.constant(8.0)))
|
baseline=tf.constant(8.0)))
|
||||||
@ -204,7 +203,7 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
|||||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||||
obs = tf.zeros((2, 3))
|
obs = tf.zeros((2, 3))
|
||||||
z = st.ObservedStochasticTensor(
|
z = st.ObservedStochasticTensor(
|
||||||
distributions.Normal, mu=mu, sigma=sigma, value=obs)
|
distributions.Normal(mu=mu, sigma=sigma), value=obs)
|
||||||
[obs_val, z_val] = sess.run([obs, z.value()])
|
[obs_val, z_val] = sess.run([obs, z.value()])
|
||||||
self.assertAllEqual(obs_val, z_val)
|
self.assertAllEqual(obs_val, z_val)
|
||||||
|
|
||||||
@ -216,13 +215,13 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
|||||||
sigma = tf.placeholder(tf.float32)
|
sigma = tf.placeholder(tf.float32)
|
||||||
obs = tf.placeholder(tf.float32)
|
obs = tf.placeholder(tf.float32)
|
||||||
z = st.ObservedStochasticTensor(
|
z = st.ObservedStochasticTensor(
|
||||||
distributions.Normal, mu=mu, sigma=sigma, value=obs)
|
distributions.Normal(mu=mu, sigma=sigma), value=obs)
|
||||||
|
|
||||||
mu2 = tf.placeholder(tf.float32, shape=[None])
|
mu2 = tf.placeholder(tf.float32, shape=[None])
|
||||||
sigma2 = tf.placeholder(tf.float32, shape=[None])
|
sigma2 = tf.placeholder(tf.float32, shape=[None])
|
||||||
obs2 = tf.placeholder(tf.float32, shape=[None, None])
|
obs2 = tf.placeholder(tf.float32, shape=[None, None])
|
||||||
z2 = st.ObservedStochasticTensor(
|
z2 = st.ObservedStochasticTensor(
|
||||||
distributions.Normal, mu=mu2, sigma=sigma2, value=obs2)
|
distributions.Normal(mu=mu2, sigma=sigma2), value=obs2)
|
||||||
|
|
||||||
coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
|
coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
|
||||||
self.assertEqual(coll, [z, z2])
|
self.assertEqual(coll, [z, z2])
|
||||||
@ -230,27 +229,19 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
|||||||
def testConstructionErrors(self):
|
def testConstructionErrors(self):
|
||||||
mu = [0., 0.]
|
mu = [0., 0.]
|
||||||
sigma = [1., 1.]
|
sigma = [1., 1.]
|
||||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
self.assertRaises(
|
||||||
distributions.Normal, mu=mu, sigma=sigma,
|
ValueError,
|
||||||
value=tf.zeros((3,)))
|
st.ObservedStochasticTensor,
|
||||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
distributions.Normal(mu=mu, sigma=sigma),
|
||||||
distributions.Normal, mu=mu, sigma=sigma,
|
value=tf.zeros((3,)))
|
||||||
value=tf.zeros((3, 1)))
|
self.assertRaises(
|
||||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
ValueError,
|
||||||
distributions.Normal, mu=mu, sigma=sigma,
|
st.ObservedStochasticTensor,
|
||||||
value=tf.zeros((1, 2), dtype=tf.int32))
|
distributions.Normal(mu=mu, sigma=sigma),
|
||||||
|
value=tf.zeros((3, 1)))
|
||||||
|
self.assertRaises(
|
||||||
class AutomaticDistributionImportTest(tf.test.TestCase):
|
ValueError,
|
||||||
|
st.ObservedStochasticTensor,
|
||||||
def testImportNormal(self):
|
distributions.Normal(mu=mu, sigma=sigma),
|
||||||
self.assertTrue(hasattr(st, "NormalTensor"))
|
value=tf.zeros(
|
||||||
self.assertTrue(callable(st.NormalTensor))
|
(1, 2), dtype=tf.int32))
|
||||||
norm = st.NormalTensor(mu=0.0, sigma=1.0)
|
|
||||||
self.assertEqual(type(norm).__name__, "NormalTensor")
|
|
||||||
self.assertTrue(isinstance(norm, st.NormalTensor))
|
|
||||||
self.assertTrue(isinstance(norm, st.StochasticTensor))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
tf.test.main()
|
|
||||||
|
@ -44,7 +44,7 @@ def mini_vae():
|
|||||||
x = [[-6., 3., 6.], [-8., 4., 8.]]
|
x = [[-6., 3., 6.], [-8., 4., 8.]]
|
||||||
prior = distributions.Normal(mu=0., sigma=1.)
|
prior = distributions.Normal(mu=0., sigma=1.)
|
||||||
variational = st.StochasticTensor(
|
variational = st.StochasticTensor(
|
||||||
distributions.Normal, mu=inference_net(x, 1), sigma=1.)
|
distributions.Normal(mu=inference_net(x, 1), sigma=1.))
|
||||||
vi.register_prior(variational, prior)
|
vi.register_prior(variational, prior)
|
||||||
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
|
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
|
||||||
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
|
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
|
||||||
@ -101,7 +101,7 @@ class VariationalInferenceTest(tf.test.TestCase):
|
|||||||
|
|
||||||
prior = distributions.Bernoulli(0.5)
|
prior = distributions.Bernoulli(0.5)
|
||||||
variational = st.StochasticTensor(
|
variational = st.StochasticTensor(
|
||||||
NormalNoEntropy, mu=inference_net(x, 1), sigma=1.)
|
NormalNoEntropy(mu=inference_net(x, 1), sigma=1.))
|
||||||
vi.register_prior(variational, prior)
|
vi.register_prior(variational, prior)
|
||||||
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
|
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
|
||||||
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
|
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
|
||||||
|
@ -44,7 +44,6 @@ from __future__ import print_function
|
|||||||
import abc
|
import abc
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
import inspect
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import six
|
import six
|
||||||
@ -79,10 +78,6 @@ class BaseStochasticTensor(object):
|
|||||||
def graph(self):
|
def graph(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def input_dict(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def value(self, name=None):
|
def value(self, name=None):
|
||||||
pass
|
pass
|
||||||
@ -120,6 +115,7 @@ class BaseStochasticTensor(object):
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
ops.register_tensor_conversion_function(
|
ops.register_tensor_conversion_function(
|
||||||
BaseStochasticTensor, BaseStochasticTensor._tensor_conversion_function)
|
BaseStochasticTensor, BaseStochasticTensor._tensor_conversion_function)
|
||||||
|
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
|
||||||
@ -223,8 +219,8 @@ class SampleAndReshapeValue(_StochasticValueType):
|
|||||||
st_value = st.value()
|
st_value = st.value()
|
||||||
assertEqual(st_value.get_shape(), (4, 3))
|
assertEqual(st_value.get_shape(), (4, 3))
|
||||||
|
|
||||||
dt_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
|
st_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
|
||||||
assertEqual(dt_value_val.shape, (4, 3))
|
assertEqual(st_value_val.shape, (4, 3))
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -312,17 +308,16 @@ class StochasticTensor(BaseStochasticTensor):
|
|||||||
"""StochasticTensor is a BaseStochasticTensor backed by a distribution."""
|
"""StochasticTensor is a BaseStochasticTensor backed by a distribution."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
dist_cls,
|
dist,
|
||||||
name=None,
|
name="StochasticTensor",
|
||||||
dist_value_type=None,
|
dist_value_type=None,
|
||||||
loss_fn=sge.score_function,
|
loss_fn=sge.score_function):
|
||||||
**dist_args):
|
|
||||||
"""Construct a `StochasticTensor`.
|
"""Construct a `StochasticTensor`.
|
||||||
|
|
||||||
`StochasticTensor` will instantiate a distribution from `dist_cls` and
|
`StochasticTensor` is backed by the `dist` distribution and its `value`
|
||||||
`dist_args` and its `value` method will return the same value each time
|
method will return the same value each time it is called. What `value` is
|
||||||
it is called. What `value` is returned is controlled by the
|
returned is controlled by the `dist_value_type` (defaults to
|
||||||
`dist_value_type` (defaults to `SampleAndReshapeValue`).
|
`SampleAndReshapeValue`).
|
||||||
|
|
||||||
Some distributions' sample functions are not differentiable (e.g. a sample
|
Some distributions' sample functions are not differentiable (e.g. a sample
|
||||||
from a discrete distribution like a Bernoulli) and so to differentiate
|
from a discrete distribution like a Bernoulli) and so to differentiate
|
||||||
@ -338,28 +333,25 @@ class StochasticTensor(BaseStochasticTensor):
|
|||||||
`MeanValueType` or if `loss_fn=None`.
|
`MeanValueType` or if `loss_fn=None`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dist_cls: a `Distribution` class.
|
dist: an instance of `Distribution`.
|
||||||
name: a name for this `StochasticTensor` and its ops.
|
name: a name for this `StochasticTensor` and its ops.
|
||||||
dist_value_type: a `_StochasticValueType`, which will determine what the
|
dist_value_type: a `_StochasticValueType`, which will determine what the
|
||||||
`value` of this `StochasticTensor` will be. If not provided, the
|
`value` of this `StochasticTensor` will be. If not provided, the
|
||||||
value type set with the `value_type` context manager will be used.
|
value type set with the `value_type` context manager will be used.
|
||||||
loss_fn: callable that takes `(st, st.value(), influenced_loss)`, where
|
loss_fn: callable that takes
|
||||||
|
`(st, st.value(), influenced_loss)`, where
|
||||||
`st` is this `StochasticTensor`, and returns a `Tensor` loss. By
|
`st` is this `StochasticTensor`, and returns a `Tensor` loss. By
|
||||||
default, `loss_fn` is the `score_function`, or more precisely, the
|
default, `loss_fn` is the `score_function`, or more precisely, the
|
||||||
integral of the score function, such that when the gradient is taken,
|
integral of the score function, such that when the gradient is taken,
|
||||||
the score function results. See the `stochastic_gradient_estimators`
|
the score function results. See the `stochastic_gradient_estimators`
|
||||||
module for additional loss functions and baselines.
|
module for additional loss functions and baselines.
|
||||||
**dist_args: keyword arguments to be passed through to `dist_cls` on
|
|
||||||
construction.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if `dist_cls` is not a `Distribution`.
|
TypeError: if `dist` is not an instance of `Distribution`.
|
||||||
TypeError: if `loss_fn` is not `callable`.
|
TypeError: if `loss_fn` is not `callable`.
|
||||||
"""
|
"""
|
||||||
if not issubclass(dist_cls, distributions.Distribution):
|
if not isinstance(dist, distributions.Distribution):
|
||||||
raise TypeError("dist_cls must be a subclass of Distribution")
|
raise TypeError("dist must be an instance of Distribution")
|
||||||
self._dist_cls = dist_cls
|
|
||||||
self._dist_args = dist_args
|
|
||||||
if dist_value_type is None:
|
if dist_value_type is None:
|
||||||
try:
|
try:
|
||||||
self._value_type = get_current_value_type()
|
self._value_type = get_current_value_type()
|
||||||
@ -371,24 +363,17 @@ class StochasticTensor(BaseStochasticTensor):
|
|||||||
with value_type(dist_value_type):
|
with value_type(dist_value_type):
|
||||||
self._value_type = get_current_value_type()
|
self._value_type = get_current_value_type()
|
||||||
|
|
||||||
self._value_type.declare_inputs(self, dist_args)
|
|
||||||
|
|
||||||
if loss_fn is not None and not callable(loss_fn):
|
if loss_fn is not None and not callable(loss_fn):
|
||||||
raise TypeError("loss_fn must be callable")
|
raise TypeError("loss_fn must be callable")
|
||||||
self._loss_fn = loss_fn
|
self._loss_fn = loss_fn
|
||||||
|
|
||||||
with ops.name_scope(name, "StochasticTensor",
|
with ops.name_scope(name) as scope:
|
||||||
dist_args.values()) as scope:
|
|
||||||
self._name = scope
|
self._name = scope
|
||||||
self._dist = dist_cls(**dist_args)
|
self._dist = dist
|
||||||
self._value = self._create_value()
|
self._value = self._create_value()
|
||||||
|
|
||||||
super(StochasticTensor, self).__init__()
|
super(StochasticTensor, self).__init__()
|
||||||
|
|
||||||
@property
|
|
||||||
def input_dict(self):
|
|
||||||
return self._dist_args
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value_type(self):
|
def value_type(self):
|
||||||
return self._value_type
|
return self._value_type
|
||||||
@ -397,9 +382,6 @@ class StochasticTensor(BaseStochasticTensor):
|
|||||||
def distribution(self):
|
def distribution(self):
|
||||||
return self._dist
|
return self._dist
|
||||||
|
|
||||||
def clone(self, name=None, **dist_args):
|
|
||||||
return StochasticTensor(self._dist_cls, name=name, **dist_args)
|
|
||||||
|
|
||||||
def _create_value(self):
|
def _create_value(self):
|
||||||
"""Create the value Tensor based on the value type, store as self._value."""
|
"""Create the value Tensor based on the value type, store as self._value."""
|
||||||
|
|
||||||
@ -494,33 +476,28 @@ class ObservedStochasticTensor(StochasticTensor):
|
|||||||
"""A StochasticTensor with an observed value."""
|
"""A StochasticTensor with an observed value."""
|
||||||
|
|
||||||
# pylint: disable=super-init-not-called
|
# pylint: disable=super-init-not-called
|
||||||
def __init__(self, dist_cls, value, name=None, **dist_args):
|
def __init__(self, dist, value, name=None):
|
||||||
"""Construct an `ObservedStochasticTensor`.
|
"""Construct an `ObservedStochasticTensor`.
|
||||||
|
|
||||||
`ObservedStochasticTensor` will instantiate a distribution from `dist_cls`
|
`ObservedStochasticTensor` is backed by distribution `dist` and uses the
|
||||||
and `dist_args` but use the provided value instead of sampling from the
|
provided value instead of using the current value type to draw a value from
|
||||||
distribution. The provided value argument must be appropriately shaped
|
the distribution. The provided value argument must be appropriately shaped
|
||||||
to have come from the constructed distribution.
|
to have come from the distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dist_cls: a `Distribution` class.
|
dist: an instance of `Distribution`.
|
||||||
value: a Tensor containing the observed value
|
value: a Tensor containing the observed value
|
||||||
name: a name for this `ObservedStochasticTensor` and its ops.
|
name: a name for this `ObservedStochasticTensor` and its ops.
|
||||||
**dist_args: keyword arguments to be passed through to `dist_cls` on
|
|
||||||
construction.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if `dist_cls` is not a `Distribution`.
|
TypeError: if `dist` is not an instance of `Distribution`.
|
||||||
ValueError: if `value` is not compatible with the distribution.
|
ValueError: if `value` is not compatible with the distribution.
|
||||||
"""
|
"""
|
||||||
if not issubclass(dist_cls, distributions.Distribution):
|
if not isinstance(dist, distributions.Distribution):
|
||||||
raise TypeError("dist_cls must be a subclass of Distribution")
|
raise TypeError("dist must be an instance of Distribution")
|
||||||
self._dist_cls = dist_cls
|
with ops.name_scope(name, "ObservedStochasticTensor", [value]) as scope:
|
||||||
self._dist_args = dist_args
|
|
||||||
with ops.name_scope(name, "ObservedStochasticTensor",
|
|
||||||
list(dist_args.values()) + [value]) as scope:
|
|
||||||
self._name = scope
|
self._name = scope
|
||||||
self._dist = dist_cls(**dist_args)
|
self._dist = dist
|
||||||
dist_shape = self._dist.get_batch_shape().concatenate(
|
dist_shape = self._dist.get_batch_shape().concatenate(
|
||||||
self._dist.get_event_shape())
|
self._dist.get_event_shape())
|
||||||
value = ops.convert_to_tensor(value)
|
value = ops.convert_to_tensor(value)
|
||||||
@ -538,7 +515,7 @@ class ObservedStochasticTensor(StochasticTensor):
|
|||||||
"sample from the distribution %s." % (value_shape, dist_shape))
|
"sample from the distribution %s." % (value_shape, dist_shape))
|
||||||
if value.dtype != self._dist.dtype:
|
if value.dtype != self._dist.dtype:
|
||||||
raise ValueError("Type of observed value (%s) does not match type of "
|
raise ValueError("Type of observed value (%s) does not match type of "
|
||||||
"distribuiton (%s)." % (value.dtype, self._dist.dtype))
|
"distribution (%s)." % (value.dtype, self._dist.dtype))
|
||||||
self._value = array_ops.identity(value)
|
self._value = array_ops.identity(value)
|
||||||
# pylint: disable=non-parent-init-called
|
# pylint: disable=non-parent-init-called
|
||||||
BaseStochasticTensor.__init__(self)
|
BaseStochasticTensor.__init__(self)
|
||||||
@ -557,39 +534,3 @@ __all__ = [
|
|||||||
"value_type",
|
"value_type",
|
||||||
"get_current_value_type",
|
"get_current_value_type",
|
||||||
]
|
]
|
||||||
|
|
||||||
_globals = globals()
|
|
||||||
# pylint: disable=redefined-builtin
|
|
||||||
__doc__ += "\n\n## Automatically Generated StochasticTensors\n\n"
|
|
||||||
# pylint: enable=redefined-builtin
|
|
||||||
for _name in sorted(dir(distributions)):
|
|
||||||
_candidate = getattr(distributions, _name)
|
|
||||||
if (inspect.isclass(_candidate)
|
|
||||||
and _candidate != distributions.Distribution
|
|
||||||
and issubclass(_candidate, distributions.Distribution)):
|
|
||||||
_local_name = "%sTensor" % _name
|
|
||||||
|
|
||||||
class _WrapperTensor(StochasticTensor):
|
|
||||||
_my_candidate = _candidate
|
|
||||||
|
|
||||||
def __init__(self, name=None, dist_value_type=None,
|
|
||||||
loss_fn=sge.score_function, **dist_args):
|
|
||||||
StochasticTensor.__init__(
|
|
||||||
self,
|
|
||||||
dist_cls=self._my_candidate,
|
|
||||||
name=name,
|
|
||||||
dist_value_type=dist_value_type,
|
|
||||||
loss_fn=loss_fn, **dist_args)
|
|
||||||
|
|
||||||
_WrapperTensor.__name__ = _local_name
|
|
||||||
_WrapperTensor.__doc__ = (
|
|
||||||
"`%s` is a `StochasticTensor` backed by the distribution `%s`."""
|
|
||||||
% (_local_name, _name))
|
|
||||||
_globals[_local_name] = _WrapperTensor
|
|
||||||
del _WrapperTensor
|
|
||||||
del _candidate
|
|
||||||
|
|
||||||
__all__.append(_local_name)
|
|
||||||
__doc__ += "@@%s\n" % _local_name
|
|
||||||
|
|
||||||
del _local_name
|
|
||||||
|
@ -126,7 +126,7 @@ def get_stochastic_variable(getter,
|
|||||||
|
|
||||||
dist_kwargs = dist_kwargs or {}
|
dist_kwargs = dist_kwargs or {}
|
||||||
dist_kwargs.update(params)
|
dist_kwargs.update(params)
|
||||||
sample = st.StochasticTensor(dist_cls, **dist_kwargs)
|
sample = st.StochasticTensor(dist_cls(**dist_kwargs))
|
||||||
|
|
||||||
if prior is not None:
|
if prior is not None:
|
||||||
if callable(prior):
|
if callable(prior):
|
||||||
|
@ -325,7 +325,7 @@ class FillLowerTriangularTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testCorrectlyMakesNoBatchLowerTril(self):
|
def testCorrectlyMakesNoBatchLowerTril(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = np.arange(9)
|
x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
|
||||||
expected = np.array(
|
expected = np.array(
|
||||||
[[0., 0., 0.],
|
[[0., 0., 0.],
|
||||||
[1., 2., 0.],
|
[1., 2., 0.],
|
||||||
@ -333,6 +333,10 @@ class FillLowerTriangularTest(tf.test.TestCase):
|
|||||||
actual = distribution_util.fill_lower_triangular(x)
|
actual = distribution_util.fill_lower_triangular(x)
|
||||||
self.assertAllEqual(expected.shape, actual.get_shape())
|
self.assertAllEqual(expected.shape, actual.get_shape())
|
||||||
self.assertAllEqual(expected, actual.eval())
|
self.assertAllEqual(expected, actual.eval())
|
||||||
|
self.assertAllEqual(
|
||||||
|
np.concatenate([np.ones(6, dtype=np.float32),
|
||||||
|
np.zeros(3, dtype=np.float32)]),
|
||||||
|
tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval())
|
||||||
|
|
||||||
def testCorrectlyMakesBatchLowerTril(self):
|
def testCorrectlyMakesBatchLowerTril(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
@ -435,15 +435,14 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
|
|||||||
"""
|
"""
|
||||||
with ops.name_scope(name, values=(x,)):
|
with ops.name_scope(name, values=(x,)):
|
||||||
x = ops.convert_to_tensor(x, name="x")
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
ndims = x.get_shape().ndims
|
if (x.get_shape().ndims is not None and
|
||||||
if ndims is not None and x.get_shape()[-1].value is not None:
|
x.get_shape()[-1].value is not None):
|
||||||
d = x.get_shape()[-1].value
|
d = x.get_shape()[-1].value
|
||||||
# d = n^2/2 + n/2 implies n is:
|
# d = n^2/2 + n/2 implies n is:
|
||||||
n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
|
n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
|
||||||
final_shape = x.get_shape()[:-1].concatenate(
|
final_shape = x.get_shape()[:-1].concatenate(
|
||||||
tensor_shape.TensorShape([n, n]))
|
tensor_shape.TensorShape([n, n]))
|
||||||
else:
|
else:
|
||||||
ndims = array_ops.rank(x)
|
|
||||||
d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
|
d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
|
||||||
# d = n^2/2 + n/2 implies n is:
|
# d = n^2/2 + n/2 implies n is:
|
||||||
n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.),
|
n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.),
|
||||||
@ -494,7 +493,12 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
|
|||||||
array_ops.tile([tril_ids], [m, 1])])
|
array_ops.tile([tril_ids], [m, 1])])
|
||||||
idx = array_ops.transpose(idx, [1, 2, 0])
|
idx = array_ops.transpose(idx, [1, 2, 0])
|
||||||
|
|
||||||
y = array_ops.gather_nd(y, idx)
|
if x.get_shape().ndims == 1:
|
||||||
|
# Prefer using gather because it has a gradient.
|
||||||
|
# We wrap the result in a list so downstream logic "just works."
|
||||||
|
y = [array_ops.gather(y[0, :], tril_ids)]
|
||||||
|
else:
|
||||||
|
y = array_ops.gather_nd(y, idx)
|
||||||
y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
|
y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
|
||||||
|
|
||||||
y.set_shape(y.get_shape().merge_with(final_shape))
|
y.set_shape(y.get_shape().merge_with(final_shape))
|
||||||
|
@ -571,9 +571,8 @@ class WALSModel(object):
|
|||||||
extras = size % num_shards
|
extras = size % num_shards
|
||||||
assignments = tf.maximum(ids // (ids_per_shard + 1),
|
assignments = tf.maximum(ids // (ids_per_shard + 1),
|
||||||
(ids - extras) // ids_per_shard)
|
(ids - extras) // ids_per_shard)
|
||||||
new_ids = tf.select(assignments < extras,
|
new_ids = tf.where(assignments < extras, ids % (ids_per_shard + 1),
|
||||||
ids % (ids_per_shard + 1),
|
(ids - extras) % ids_per_shard)
|
||||||
(ids - extras) % ids_per_shard)
|
|
||||||
return assignments, new_ids
|
return assignments, new_ids
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ class LocalVariableTest(tf.test.TestCase):
|
|||||||
variables = tf.local_variables()
|
variables = tf.local_variables()
|
||||||
self.assertEquals(2, len(variables))
|
self.assertEquals(2, len(variables))
|
||||||
self.assertRaises(tf.OpError, sess.run, variables)
|
self.assertRaises(tf.OpError, sess.run, variables)
|
||||||
tf.initialize_variables(variables).run()
|
tf.variables_initializer(variables).run()
|
||||||
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
|
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
|
||||||
|
|
||||||
def testLocalVariableNameAndShape(self):
|
def testLocalVariableNameAndShape(self):
|
||||||
@ -51,7 +51,7 @@ class LocalVariableTest(tf.test.TestCase):
|
|||||||
with self.test_session():
|
with self.test_session():
|
||||||
with tf.variable_scope('A'):
|
with tf.variable_scope('A'):
|
||||||
a = tf.contrib.framework.local_variable(0)
|
a = tf.contrib.framework.local_variable(0)
|
||||||
self.assertFalse(a in tf.all_variables())
|
self.assertFalse(a in tf.global_variables())
|
||||||
self.assertTrue(a in tf.local_variables())
|
self.assertTrue(a in tf.local_variables())
|
||||||
|
|
||||||
def testLocalVariableNotInVariablesToRestore(self):
|
def testLocalVariableNotInVariablesToRestore(self):
|
||||||
@ -82,7 +82,7 @@ class LocalVariableTest(tf.test.TestCase):
|
|||||||
def testInitializedVariableValue(self):
|
def testInitializedVariableValue(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
|
a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
|
||||||
sess.run(tf.initialize_local_variables())
|
sess.run(tf.local_variables_initializer())
|
||||||
self.assertAllEqual(a.eval(), [0]*5)
|
self.assertAllEqual(a.eval(), [0]*5)
|
||||||
|
|
||||||
|
|
||||||
@ -439,7 +439,7 @@ class ModelVariablesTest(tf.test.TestCase):
|
|||||||
with self.test_session():
|
with self.test_session():
|
||||||
with tf.variable_scope('A'):
|
with tf.variable_scope('A'):
|
||||||
a = tf.contrib.framework.model_variable('a', [5])
|
a = tf.contrib.framework.model_variable('a', [5])
|
||||||
self.assertTrue(a in tf.all_variables())
|
self.assertTrue(a in tf.global_variables())
|
||||||
self.assertTrue(a in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
|
self.assertTrue(a in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
|
||||||
self.assertFalse(a in tf.local_variables())
|
self.assertFalse(a in tf.local_variables())
|
||||||
|
|
||||||
@ -474,7 +474,7 @@ class ModelVariablesTest(tf.test.TestCase):
|
|||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
a = tf.contrib.framework.model_variable(
|
a = tf.contrib.framework.model_variable(
|
||||||
'a', [5], initializer=tf.ones_initializer)
|
'a', [5], initializer=tf.ones_initializer)
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
self.assertAllEqual(a.eval(), [1]*5)
|
self.assertAllEqual(a.eval(), [1]*5)
|
||||||
|
|
||||||
def testDeviceFn(self):
|
def testDeviceFn(self):
|
||||||
@ -667,7 +667,7 @@ class AssignFromValuesTest(tf.test.TestCase):
|
|||||||
var_names_to_values)
|
var_names_to_values)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
sess.run(assign_op, feed_dict)
|
sess.run(assign_op, feed_dict)
|
||||||
@ -697,7 +697,7 @@ class AssignFromValuesTest(tf.test.TestCase):
|
|||||||
var_names_to_values)
|
var_names_to_values)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
sess.run(assign_op, feed_dict)
|
sess.run(assign_op, feed_dict)
|
||||||
@ -725,7 +725,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
|
|||||||
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
|
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
init_fn(sess)
|
init_fn(sess)
|
||||||
@ -754,7 +754,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
|
|||||||
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
|
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
init_fn(sess)
|
init_fn(sess)
|
||||||
@ -786,7 +786,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
|||||||
var_value = var_names_to_values[var_name]
|
var_value = var_names_to_values[var_name]
|
||||||
var_list.append(tf.Variable(var_value, name=var_name))
|
var_list.append(tf.Variable(var_value, name=var_name))
|
||||||
saver = tf.train.Saver(var_list)
|
saver = tf.train.Saver(var_list)
|
||||||
init_op = tf.initialize_variables(var_list)
|
init_op = tf.variables_initializer(var_list)
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
# Save the initialized values in the file at 'checkpoint_dir'
|
# Save the initialized values in the file at 'checkpoint_dir'
|
||||||
return saver.save(sess, checkpoint_dir, global_step=global_step)
|
return saver.save(sess, checkpoint_dir, global_step=global_step)
|
||||||
@ -808,7 +808,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
|||||||
model_path, vars_to_restore)
|
model_path, vars_to_restore)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
sess.run(op, feed_dict)
|
sess.run(op, feed_dict)
|
||||||
@ -859,7 +859,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
|||||||
vars_to_restore)
|
vars_to_restore)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
sess.run(op, feed_dict)
|
sess.run(op, feed_dict)
|
||||||
@ -890,7 +890,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
|||||||
var_value = var_names_to_values[var_name]
|
var_value = var_names_to_values[var_name]
|
||||||
var_list.append(tf.Variable(var_value, name=var_name))
|
var_list.append(tf.Variable(var_value, name=var_name))
|
||||||
saver = tf.train.Saver(var_list)
|
saver = tf.train.Saver(var_list)
|
||||||
init_op = tf.initialize_variables(var_list)
|
init_op = tf.variables_initializer(var_list)
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
# Save the initialized values in the file at 'checkpoint_dir'
|
# Save the initialized values in the file at 'checkpoint_dir'
|
||||||
return saver.save(sess, checkpoint_dir, global_step=global_step)
|
return saver.save(sess, checkpoint_dir, global_step=global_step)
|
||||||
@ -912,7 +912,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
|||||||
model_path, vars_to_restore)
|
model_path, vars_to_restore)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
init_fn(sess)
|
init_fn(sess)
|
||||||
@ -938,7 +938,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
|||||||
model_path, vars_to_restore)
|
model_path, vars_to_restore)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
@ -961,7 +961,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
|||||||
model_path, vars_to_restore, reshape_variables=True)
|
model_path, vars_to_restore, reshape_variables=True)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
init_fn(sess)
|
init_fn(sess)
|
||||||
@ -989,7 +989,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
|||||||
vars_to_restore)
|
vars_to_restore)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
with self.assertRaises(tf.errors.NotFoundError):
|
with self.assertRaises(tf.errors.NotFoundError):
|
||||||
@ -1015,7 +1015,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
|||||||
ignore_missing_vars=True)
|
ignore_missing_vars=True)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
init_fn(sess)
|
init_fn(sess)
|
||||||
@ -1044,7 +1044,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
|||||||
ignore_missing_vars=True)
|
ignore_missing_vars=True)
|
||||||
|
|
||||||
# Initialize the variables.
|
# Initialize the variables.
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# Perform the assignment.
|
# Perform the assignment.
|
||||||
init_fn(sess)
|
init_fn(sess)
|
||||||
|
38
tensorflow/contrib/integrate/BUILD
Normal file
38
tensorflow/contrib/integrate/BUILD
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# Description:
|
||||||
|
# Integration and ODE solvers for TensorFlow.
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "integrate_py",
|
||||||
|
srcs = [
|
||||||
|
"__init__.py",
|
||||||
|
"python/ops/odes.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "odes_test",
|
||||||
|
srcs = ["python/ops/odes_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":integrate_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*"],
|
||||||
|
exclude = [
|
||||||
|
"**/METADATA",
|
||||||
|
"**/OWNERS",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
9
tensorflow/contrib/integrate/README.md
Normal file
9
tensorflow/contrib/integrate/README.md
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
# Integration and ODE solvers for TensorFlow
|
||||||
|
|
||||||
|
TensorFlow equivalents to the routines provided by `scipy.integrate`. Currently
|
||||||
|
contains a single function, `odeint`, for integrating ordinary differential
|
||||||
|
equations.
|
||||||
|
|
||||||
|
Maintainers:
|
||||||
|
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
|
||||||
|
- Marc Coram (mcoram@google.com, github.com/mcoram)
|
64
tensorflow/contrib/integrate/__init__.py
Normal file
64
tensorflow/contrib/integrate/__init__.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Integration and ODE solvers for TensorFlow.
|
||||||
|
|
||||||
|
## Example: Lorenz attractor
|
||||||
|
|
||||||
|
We can use `odeint` to solve the
|
||||||
|
[Lorenz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
|
||||||
|
differential equations, a prototypical example of chaotic dynamics:
|
||||||
|
|
||||||
|
```python
|
||||||
|
rho = 28.0
|
||||||
|
sigma = 10.0
|
||||||
|
beta = 8.0/3.0
|
||||||
|
|
||||||
|
def lorenz_equation(state, t):
|
||||||
|
x, y, z = tf.unpack(state)
|
||||||
|
dx = sigma * (y - x)
|
||||||
|
dy = x * (rho - z) - y
|
||||||
|
dz = x * y - beta * z
|
||||||
|
return tf.pack([dx, dy, dz])
|
||||||
|
|
||||||
|
init_state = tf.constant([0, 2, 20], dtype=tf.float64)
|
||||||
|
t = np.linspace(0, 50, num=5000)
|
||||||
|
tensor_state, tensor_info = tf.contrib.integrate.odeint(
|
||||||
|
lorenz_equation, init_state, t, full_output=True)
|
||||||
|
|
||||||
|
sess = tf.Session()
|
||||||
|
state, info = sess.run([tensor_state, tensor_info])
|
||||||
|
x, y, z = state.T
|
||||||
|
plt.plot(x, z)
|
||||||
|
```
|
||||||
|
|
||||||
|
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||||
|
<img style="width:100%" src="../../images/lorenz_attractor.png" alt>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Ops
|
||||||
|
|
||||||
|
@@odeint
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.contrib.integrate.python.ops.odes import *
|
||||||
|
from tensorflow.python.util.all_util import make_all
|
||||||
|
|
||||||
|
__all__ = make_all(__name__)
|
503
tensorflow/contrib/integrate/python/ops/odes.py
Normal file
503
tensorflow/contrib/integrate/python/ops/odes.py
Normal file
@ -0,0 +1,503 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""ODE solvers for TensorFlow."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
|
|
||||||
|
|
||||||
|
# Coefficients of an explicit Runge-Kutta scheme, as consumed by the helpers
# in this file:
#   alpha:   per-stage fractions of `dt`; `_runge_kutta_step` evaluates each
#            later stage at `t0 + alpha_i * dt`.
#   beta:    one row per later stage; row i weights the derivatives `k`
#            computed so far when forming that stage's state.
#   c_sol:   weights on `k` used for the solution at the end of the step.
#   c_mid:   weights on `k` used by `_interp_fit_rk` for the mid-point state.
#   c_error: weights on `k` used for the error estimate of the step.
_ButcherTableau = collections.namedtuple(
    '_ButcherTableau', 'alpha beta c_sol c_mid c_error')

# Parameters from Shampine (1986), section 4.
# The fractions below rely on `from __future__ import division` (imported at
# the top of this module) so that e.g. `1/5` is a float, not integer division.
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
    alpha=[1/5, 3/10, 4/5, 8/9, 1., 1.],
    beta=[[1/5],
          [3/40, 9/40],
          [44/45, -56/15, 32/9],
          [19372/6561, -25360/2187, 64448/6561, -212/729],
          [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
          [35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]],
    # Same leading entries as the last row of `beta`, plus a trailing zero.
    c_sol=[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0],
    c_mid=[6025192743/30085553152 / 2, 0, 51252292925/65400821598 / 2,
           -2691868925/45128329728 / 2, 187940372067/1594534317056 / 2,
           -1776094331/19743644256 / 2, 11237099/235043384 / 2],
    # Each entry is written as the difference between two sets of weights; the
    # subtracted values are the `c_sol` entries.
    c_error=[1951/21600 - 35/384,
             0,
             22642/50085 - 500/1113,
             451/720 - 125/192,
             -12231/42400 - -2187/6784,
             649/6300 - 11/84,
             1/60],
)
|
||||||
|
|
||||||
|
|
||||||
|
def _possibly_nonzero(x):
  """Report whether `x` has to be treated as potentially nonzero.

  Any `Tensor` counts as possibly nonzero; every other value is compared
  against zero directly.
  """
  if isinstance(x, ops.Tensor):
    return True
  return x != 0
|
||||||
|
|
||||||
|
|
||||||
|
def _scaled_dot_product(scale, xs, ys, name=None):
  """Calculate a scaled, vector inner product between lists of Tensors.

  Computes the sum of `(scale * x) * y` over the pairs from `zip(xs, ys)`,
  leaving out every pair in which either factor is a non-Tensor value equal
  to zero.

  Args:
    scale: multiplier applied to every term.
    xs: list of coefficients (Python numbers or Tensors).
    ys: list of values paired element-wise with `xs`.
    name: optional name for the operation.

  Returns:
    Tensor holding the scaled sum of the remaining products. At least one
    pair must survive the zero filter, because the sum is built with
    `math_ops.add_n`.
  """
  with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
    # Some of the parameters in our Butcher tableau include zeros. A term can
    # only contribute when *both* factors are possibly nonzero, so it is
    # dropped as soon as either one is a known zero; this avoids wasted
    # computation.
    return math_ops.add_n([(scale * x) * y for x, y in zip(xs, ys)
                           if _possibly_nonzero(x) and _possibly_nonzero(y)],
                          name=scope)
|
||||||
|
|
||||||
|
|
||||||
|
def _dot_product(xs, ys, name=None):
  """Sum the element-wise products of two equally ordered lists.

  Args:
    xs: first sequence of factors.
    ys: second sequence of factors, paired with `xs` via `zip`.
    name: optional name for the operation.

  Returns:
    The result of `math_ops.add_n` over all `x * y` products.
  """
  with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
    products = []
    for x, y in zip(xs, ys):
      products.append(x * y)
    return math_ops.add_n(products, name=scope)
|
||||||
|
|
||||||
|
|
||||||
|
def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU,
                      name=None):
  """Take an arbitrary Runge-Kutta step and estimate error.

  Args:
    func: Function to evaluate like `func(y, t)` to compute the time derivative
      of `y`.
    y0: Tensor initial value for the state.
    f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
    t0: float64 scalar Tensor giving the initial time.
    dt: float64 scalar Tensor giving the size of the desired time step.
    tableau: optional _ButcherTableau describing how to take the Runge-Kutta
      step.
    name: optional name for the operation.

  Returns:
    Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
    the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
    estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
    calculating these terms.
  """
  with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
    y0 = ops.convert_to_tensor(y0, name='y0')
    f0 = ops.convert_to_tensor(f0, name='f0')
    t0 = ops.convert_to_tensor(t0, name='t0')
    dt = ops.convert_to_tensor(dt, name='dt')
    # Times stay in the dtype of `t0`/`dt`; state updates use the dtype of
    # `y0`.
    dt_cast = math_ops.cast(dt, y0.dtype)

    # k[0] is the derivative at the start of the step; each later stage is
    # built from all derivatives computed so far.
    k = [f0]
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
      ti = t0 + alpha_i * dt
      yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
      k.append(func(yi, ti))

    # When the solution weights equal the last row of `beta` followed by a
    # zero (true for Dormand-Prince), the state of the final stage already is
    # the solution and recomputing it would only waste FLOPs. `c_sol` has one
    # more entry than `beta[-1]`, so the trailing entry is checked separately
    # and only the leading entries are compared.
    if not (tableau.c_sol[-1] == 0 and
            tableau.c_sol[:-1] == tableau.beta[-1]):
      yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)

    y1 = array_ops.identity(yi, name='%s/y1' % scope)
    f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
    y1_error = _scaled_dot_product(dt_cast, tableau.c_error, k,
                                   name='%s/y1_error' % scope)
    return (y1, f1, y1_error, k)
|
||||||
|
|
||||||
|
|
||||||
|
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
  """Fit coefficients for 4th order polynomial interpolation.

  The polynomial `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` is
  parameterised by `x` running from 0 (start of interval) to 1 (end of
  interval).

  Args:
    y0: function value at the start of the interval.
    y1: function value at the end of the interval.
    y_mid: function value at the mid-point of the interval.
    f0: derivative value at the start of the interval.
    f1: derivative value at the end of the interval.
    dt: width of the interval.

  Returns:
    List of coefficients `[a, b, c, d, e]`, highest power first.
  """
  # The closed form below is the solution of the five constraints
  #   p(0) = y0, p(1/2) = y_mid, p(1) = y1, p'(0) / dt = f0, p'(1) / dt = f1
  # and can be reproduced with sympy:
  #   a, b, c, d, e = sympy.symbols('a b c d e')
  #   x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
  #   p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
  #   sympy.solve([p.subs(x, 0) - y0,
  #                p.subs(x, 1 / 2) - y_mid,
  #                p.subs(x, 1) - y1,
  #                (p.diff(x) / dt).subs(x, 0) - f0,
  #                (p.diff(x) / dt).subs(x, 1) - f1],
  #               [a, b, c, d, e])
  # which yields
  #   a = -2*dt*f0 + 2*dt*f1 -  8*y0 -  8*y1 + 16*y_mid
  #   b =  5*dt*f0 - 3*dt*f1 + 18*y0 + 14*y1 - 32*y_mid
  #   c = -4*dt*f0 +   dt*f1 - 11*y0 -  5*y1 + 16*y_mid
  #   d = dt*f0
  #   e = y0
  samples = [f0, f1, y0, y1, y_mid]
  quartic = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], samples)
  cubic = _dot_product([5 * dt, -3 * dt, 18, 14, -32], samples)
  quadratic = _dot_product([-4 * dt, dt, -11, -5, 16], samples)
  linear = dt * f0
  constant = y0
  return [quartic, cubic, quadratic, linear, constant]
|
||||||
|
|
||||||
|
|
||||||
|
def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
  """Build interpolation coefficients from the output of a Runge-Kutta step.

  Args:
    y0: state at the start of the step.
    y1: state at the end of the step.
    k: list of derivatives produced by the step; the first and last entries
      are used as the derivatives at the two ends of the interval.
    dt: size of the step; cast to the dtype of `y0`.
    tableau: optional _ButcherTableau supplying the `c_mid` weights.

  Returns:
    The coefficient list returned by `_interp_fit`.
  """
  with ops.name_scope('interp_fit_rk'):
    step = math_ops.cast(dt, y0.dtype)
    midpoint = y0 + _scaled_dot_product(step, tableau.c_mid, k)
    start_derivative = k[0]
    end_derivative = k[-1]
    return _interp_fit(
        y0, y1, midpoint, start_derivative, end_derivative, step)
|
||||||
|
|
||||||
|
|
||||||
|
def _interp_evaluate(coefficients, t0, t1, t):
  """Evaluate polynomial interpolation at the given time point.

  Args:
    coefficients: list of Tensor coefficients as created by `_interp_fit`,
      highest power first.
    t0: scalar float64 Tensor giving the start of the interval.
    t1: scalar float64 Tensor giving the end of the interval.
    t: scalar float64 Tensor giving the desired interpolation point.

  Returns:
    Polynomial interpolation of the coefficients at time `t`.
  """
  with ops.name_scope('interp_evaluate'):
    t0 = ops.convert_to_tensor(t0)
    t1 = ops.convert_to_tensor(t1)
    t = ops.convert_to_tensor(t)

    dtype = coefficients[0].dtype

    # The normalised position is only computed once the range check passed.
    inside_interval = (t0 <= t) & (t <= t1)
    assert_op = control_flow_ops.Assert(
        inside_interval,
        ['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
    with ops.control_dependencies([assert_op]):
      x = math_ops.cast((t - t0) / (t1 - t0), dtype)

    # Powers of x in ascending order: 1, x, x ** 2, ...
    powers = [constant_op.constant(1, dtype), x]
    while len(powers) < len(coefficients):
      powers.append(powers[-1] * x)

    # Coefficients are ordered highest power first, hence the reversal.
    return _dot_product(coefficients, reversed(powers))
|
||||||
|
|
||||||
|
|
||||||
|
def _optimal_step_size(last_step,
                       error_ratio,
                       safety=0.9,
                       ifactor=10.0,
                       dfactor=0.2,
                       order=5,
                       name=None):
  """Calculate the optimal size for the next Runge-Kutta step.

  Args:
    last_step: Tensor with the size of the step that was just attempted.
    error_ratio: ratio of the estimated error to the tolerance for that step.
    safety: safety factor dividing the error term.
    ifactor: bounds how much the step may grow (divisor of at least
      `1 / ifactor`).
    dfactor: bounds how much the step may shrink (divisor of at most
      `1 / dfactor`).
    order: the exponent applied to `error_ratio` is `1 / order`.
    name: optional name for the operation.

  Returns:
    `last_step` divided by the clamped scaling factor.
  """
  with ops.name_scope(
      name, 'optimal_step_size', [last_step, error_ratio]) as scope:
    step_dtype = last_step.dtype
    error_ratio = math_ops.cast(error_ratio, step_dtype)
    exponent = math_ops.cast(1 / order, step_dtype)
    # The step is divided by `factor` instead of multiplied by its inverse;
    # this keeps error_ratio in the numerator so we can't divide by zero.
    smallest_divisor = 1 / ifactor
    scaled_error = error_ratio ** exponent / safety
    largest_divisor = 1 / dfactor
    factor = math_ops.maximum(
        smallest_divisor, math_ops.minimum(scaled_error, largest_divisor))
    return math_ops.div(last_step, factor, name=scope)
|
||||||
|
|
||||||
|
|
||||||
|
def _abs_square(x):
  """Return the element-wise squared magnitude of `x`.

  Complex inputs use `real ** 2 + imag ** 2`; all other dtypes are squared
  directly.
  """
  if not x.dtype.is_complex:
    return math_ops.square(x)
  real_squared = math_ops.square(math_ops.real(x))
  imag_squared = math_ops.square(math_ops.imag(x))
  return real_squared + imag_squared
|
||||||
|
|
||||||
|
|
||||||
|
def _ta_append(tensor_array, value):
  """Write `value` at index `tensor_array.size()` of a tf.TensorArray.

  Returns:
    The TensorArray returned by the `write` call.
  """
  next_index = tensor_array.size()
  return tensor_array.write(next_index, value)
|
||||||
|
|
||||||
|
|
||||||
|
class _RungeKuttaState(collections.namedtuple(
    '_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
  """Saved state of the Runge Kutta solver.

  Attributes:
    y1: Tensor giving the function value at the end of the last time step.
    f1: Tensor giving derivative at the end of the last time step.
    t0: scalar float64 Tensor giving start of the last time step.
    t1: scalar float64 Tensor giving end of the last time step.
    dt: scalar float64 Tensor giving the size for the next time step.
    interp_coeff: list of Tensors giving coefficients for polynomial
      interpolation between `t0` and `t1`.
  """
|
||||||
|
|
||||||
|
|
||||||
|
class _History(collections.namedtuple(
    '_History', 'integrate_points, error_ratio')):
  """Saved integration history for use in `info_dict`.

  `_dopri5` appends one entry to each array per attempted Runge-Kutta step.

  Attributes:
    integrate_points: tf.TensorArray storing integrating time points.
    error_ratio: tf.TensorArray storing computed error ratios at each
      integration step.
  """
|
||||||
|
|
||||||
|
|
||||||
|
def _dopri5(func,
            y0,
            t,
            rtol,
            atol,
            full_output=False,
            first_step=None,
            safety=0.9,
            ifactor=10.0,
            dfactor=0.2,
            max_num_steps=1000,
            name=None):
  """Solve an ODE for `odeint` using method='dopri5'.

  Args:
    func: callable evaluated as `func(y, t)` to obtain the time derivative.
    y0: Tensor with the initial state.
    t: 1-D Tensor of strictly increasing time points; `t[0]` is the start time.
    rtol: Tensor with the relative error tolerance.
    atol: Tensor with the absolute error tolerance.
    full_output: if True, also return a dict describing the integration.
    first_step: size of the first attempted step (1.0 when None).
    safety: forwarded to `_optimal_step_size`.
    ifactor: forwarded to `_optimal_step_size`.
    dfactor: forwarded to `_optimal_step_size`.
    max_num_steps: maximum number of Runge-Kutta steps attempted between two
      consecutive entries of `t`.
    name: optional name for the operation.

  Returns:
    Tensor `y` with the solution stacked along a new first dimension, one
    entry per time point, or the tuple `(y, info_dict)` when `full_output` is
    True.
  """

  if first_step is None:
    # at some point, we might want to switch to picking the step size
    # automatically
    first_step = 1.0

  with ops.name_scope(
      name, 'dopri5',
      [y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps]) as scope:

    # Step-control parameters share the dtype of the time points.
    first_step = ops.convert_to_tensor(first_step, dtype=t.dtype,
                                       name='first_step')
    safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
    ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
    dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
    max_num_steps = ops.convert_to_tensor(max_num_steps, dtype=dtypes.int32,
                                          name='max_num_steps')

    def adaptive_runge_kutta_step(rk_state, history, n_steps):
      """Take an adaptive Runge-Kutta step to integrate the ODE."""
      # The *end* of the previous step (fields y1, f1, t1) is the start of
      # this one, hence the local names y0, f0, t0; the previous start time is
      # discarded.
      y0, f0, _, t0, dt, interp_coeff = rk_state
      with ops.name_scope('assertions'):
        check_underflow = control_flow_ops.Assert(
            t0 + dt > t0, ['underflow in dt', dt])
        check_max_num_steps = control_flow_ops.Assert(
            n_steps < max_num_steps, ['max_num_steps exceeded'])
        check_numerics = control_flow_ops.Assert(
            math_ops.reduce_all(math_ops.is_finite(abs(y0))),
            ['non-finite values in state `y`', y0])
      with ops.control_dependencies(
          [check_underflow, check_max_num_steps, check_numerics]):
        y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)

      with ops.name_scope('error_ratio'):
        # We use the same approach as the dopri5 fortran code.
        error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
        tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
        # Could also use reduce_maximum here.
        error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
        accept_step = error_ratio <= 1

      with ops.name_scope('update/rk_state'):
        # If we don't accept the step, the _RungeKuttaState will be useless
        # (covering a time-interval of size 0), but that's OK, because in such
        # cases we always immediately take another Runge-Kutta step.
        y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
        f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
        t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
        interp_coeff = control_flow_ops.cond(
            accept_step,
            lambda: _interp_fit_rk(y0, y1, k, dt),
            lambda: interp_coeff)
        # The next step size is recomputed whether or not the step was
        # accepted.
        dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
        rk_state = _RungeKuttaState(
            y_next, f_next, t0, t_next, dt_next, interp_coeff)

      with ops.name_scope('update/history'):
        # Every attempted step is recorded, including rejected ones.
        history = _History(_ta_append(history.integrate_points, t0 + dt),
                           _ta_append(history.error_ratio, error_ratio))
      return rk_state, history, n_steps + 1

    def interpolate(solution, history, rk_state, i):
      """Interpolate through the next time point, integrating as necessary."""
      with ops.name_scope('interpolate'):
        # Keep stepping until the last step reaches t[i]; the step counter
        # restarts at 0 for every requested time point.
        rk_state, history, _ = control_flow_ops.while_loop(
            lambda rk_state, *_: t[i] > rk_state.t1,
            adaptive_runge_kutta_step,
            (rk_state, history, 0),
            name='integrate_loop')
        y = _interp_evaluate(
            rk_state.interp_coeff, rk_state.t0, rk_state.t1, t[i])
        solution = solution.write(i, y)
        return solution, history, rk_state, i + 1

    assert_increasing = control_flow_ops.Assert(
        math_ops.reduce_all(t[1:] > t[:-1]),
        ['`t` must be monotonic increasing'])
    with ops.control_dependencies([assert_increasing]):
      num_times = array_ops.size(t)

    # Entry 0 of the solution is the initial state itself.
    solution = tensor_array_ops.TensorArray(
        y0.dtype, size=num_times).write(0, y0)
    history = _History(
        integrate_points=tensor_array_ops.TensorArray(
            t.dtype, size=0, dynamic_size=True),
        error_ratio=tensor_array_ops.TensorArray(
            rtol.dtype, size=0, dynamic_size=True))
    # Initial state covers the zero-width interval [t[0], t[0]]; the
    # interpolation coefficients are placeholders built from `y0`.
    rk_state = _RungeKuttaState(
        y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)

    # Outer loop over the requested time points, starting at index 1.
    solution, history, _, _ = control_flow_ops.while_loop(
        lambda _, __, ___, i: i < num_times,
        interpolate,
        (solution, history, rk_state, 1),
        name='interpolate_loop')

    y = solution.pack(name=scope)
    y.set_shape(t.get_shape().concatenate(y0.get_shape()))
    if not full_output:
      return y
    else:
      integrate_points = history.integrate_points.pack()
      # Six evaluations of `func` per attempted step (one per tableau stage)
      # plus the initial `func(y0, t[0])`.
      info_dict = {'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
                   'integrate_points': integrate_points,
                   'error_ratio': history.error_ratio.pack()}
      return (y, info_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def odeint(func,
           y0,
           t,
           rtol=1e-6,
           atol=1e-12,
           method=None,
           options=None,
           full_output=False,
           name=None):
  """Integrate a system of ordinary differential equations.

  Solves the initial value problem for a non-stiff system of first order ODEs:

    ```
    dy/dt = func(y, t), y(t[0]) = y0
    ```

  where y is a Tensor of any shape.

  For example:

    ```
    # solve `dy/dt = -y`, corresponding to exponential decay
    tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
    => [1, exp(-1), exp(-2)]
    ```

  Output dtypes and numerical precision are based on the dtypes of the inputs
  `y0` and `t`.

  Currently, implements 5th order Runge-Kutta with adaptive step size control
  and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
  method of `scipy.integrate.ode` and MATLAB's `ode45`.

  Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
  Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
  doi:10.2307/2008219

  Args:
    func: Function that maps a Tensor holding the state `y` and a scalar Tensor
      `t` into a Tensor of state derivatives with respect to time.
    y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
      have any floating point or complex dtype.
    t: 1-D Tensor holding a sequence of time points for which to solve for
      `y`. The initial time point should be the first element of this sequence,
      and each time must be larger than the previous time. May have any floating
      point dtype. If not provided as a Tensor, converted to a Tensor with
      float64 dtype.
    rtol: optional Tensor specifying an upper bound on relative error, per
      element of `y`. Converted to the dtype of `abs(y0)`.
    atol: optional Tensor specifying an upper bound on absolute error, per
      element of `y`. Converted to the dtype of `abs(y0)`.
    method: optional string indicating the integration method to use. Currently,
      the only valid option is `'dopri5'`.
    options: optional dict of configuring options for the indicated integration
      method. Can only be provided if a `method` is explicitly set. For
      `'dopri5'`, valid options include:
      * first_step: an initial guess for the size of the first integration
        step (current default: 1.0, but may later be changed to use heuristics
        based on the gradient).
      * safety: safety factor for adaptive step control, generally a constant
        in the range 0.8-1 (default: 0.9).
      * ifactor: maximum factor by which the adaptive step may be increased
        (default: 10.0).
      * dfactor: maximum factor by which the adaptive step may be decreased
        (default: 0.2).
      * max_num_steps: integer maximum number of integrate steps between time
        points in `t` (default: 1000).
    full_output: optional boolean. If True, `odeint` returns a tuple
      `(y, info_dict)` describing the integration process.
    name: Optional name for this operation.

  Returns:
    y: (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each desired time point in
      `t`, with the initial value `y0` being the first element along the first
      dimension.
    info_dict: only if `full_output == True`. A dict with the following values:
      * num_func_evals: integer Tensor counting the number of function
        evaluations.
      * integrate_points: 1D Tensor, with the dtype of `t`, holding the upper
        bound of each integration time step.
      * error_ratio: 1D float Tensor with the estimated ratio of the integration
        error to the error tolerance at each integration step. A ratio greater
        than 1 corresponds to rejected steps.

  Raises:
    ValueError: if an invalid `method` is provided, or if `options` is supplied
      without `method`.
    TypeError: if `t` or `y0` has an invalid dtype.
  """
  if method is not None and method != 'dopri5':
    raise ValueError('invalid method: %r' % method)

  if options is None:
    options = {}
  elif method is None:
    raise ValueError('cannot supply `options` without specifying `method`')

  with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
    # TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
    # arbitrarily nested tuple. This will help performance and usability by
    # avoiding the need to pack/unpack in user functions.
    y0 = ops.convert_to_tensor(y0, name='y0')
    if not (y0.dtype.is_floating or y0.dtype.is_complex):
      raise TypeError('`y0` must have a floating point or complex floating '
                      'point dtype')

    t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    if not t.dtype.is_floating:
      raise TypeError('`t` must have a floating point dtype')

    # Tolerances use the dtype of `abs(y0)`.
    error_dtype = abs(y0).dtype
    rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
    atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')

    # `options` is forwarded as keyword arguments to the solver.
    return _dopri5(func, y0, t,
                   rtol=rtol,
                   atol=atol,
                   full_output=full_output,
                   name=scope,
                   **options)
|
232
tensorflow/contrib/integrate/python/ops/odes_test.py
Normal file
232
tensorflow/contrib/integrate/python/ops/odes_test.py
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for ODE solvers."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.integrate.python.ops import odes
|
||||||
|
|
||||||
|
|
||||||
|
class OdeIntTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(OdeIntTest, self).setUp()
|
||||||
|
# simple defaults (solution is a sin-wave)
|
||||||
|
matrix = tf.constant([[0, 1], [-1, 0]], dtype=tf.float64)
|
||||||
|
self.func = lambda y, t: tf.matmul(matrix, y)
|
||||||
|
self.y0 = np.array([[1.0], [0.0]])
|
||||||
|
|
||||||
|
def test_odeint_exp(self):
|
||||||
|
# Test odeint by an exponential function:
|
||||||
|
# dy / dt = y, y(0) = 1.0.
|
||||||
|
# Its analytical solution is y = exp(t).
|
||||||
|
func = lambda y, t: y
|
||||||
|
y0 = tf.constant(1.0, dtype=tf.float64)
|
||||||
|
t = np.linspace(0.0, 1.0, 11)
|
||||||
|
y_solved = tf.contrib.integrate.odeint(func, y0, t)
|
||||||
|
self.assertIn('odeint', y_solved.name)
|
||||||
|
self.assertEqual(y_solved.get_shape(), tf.TensorShape([11]))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_solved = sess.run(y_solved)
|
||||||
|
y_true = np.exp(t)
|
||||||
|
self.assertAllClose(y_true, y_solved)
|
||||||
|
|
||||||
|
def test_odeint_complex(self):
|
||||||
|
# Test a complex, linear ODE:
|
||||||
|
# dy / dt = k * y, y(0) = 1.0.
|
||||||
|
# Its analytical solution is y = exp(k * t).
|
||||||
|
k = 1j - 0.1
|
||||||
|
func = lambda y, t: k * y
|
||||||
|
t = np.linspace(0.0, 1.0, 11)
|
||||||
|
y_solved = tf.contrib.integrate.odeint(func, 1.0 + 0.0j, t)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_solved = sess.run(y_solved)
|
||||||
|
y_true = np.exp(k * t)
|
||||||
|
self.assertAllClose(y_true, y_solved)
|
||||||
|
|
||||||
|
def test_odeint_riccati(self):
|
||||||
|
# The Riccati equation is:
|
||||||
|
# dy / dt = (y - t) ** 2 + 1.0, y(0) = 0.5.
|
||||||
|
# Its analytical solution is y = 1.0 / (2.0 - t) + t.
|
||||||
|
func = lambda t, y: (y - t)**2 + 1.0
|
||||||
|
t = np.linspace(0.0, 1.0, 11)
|
||||||
|
y_solved = tf.contrib.integrate.odeint(func, np.float64(0.5), t)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_solved = sess.run(y_solved)
|
||||||
|
y_true = 1.0 / (2.0 - t) + t
|
||||||
|
self.assertAllClose(y_true, y_solved)
|
||||||
|
|
||||||
|
def test_odeint_2d_linear(self):
|
||||||
|
# Solve the 2D linear differential equation:
|
||||||
|
# dy1 / dt = 3.0 * y1 + 4.0 * y2,
|
||||||
|
# dy2 / dt = -4.0 * y1 + 3.0 * y2,
|
||||||
|
# y1(0) = 0.0,
|
||||||
|
# y2(0) = 1.0.
|
||||||
|
# Its analytical solution is
|
||||||
|
# y1 = sin(4.0 * t) * exp(3.0 * t),
|
||||||
|
# y2 = cos(4.0 * t) * exp(3.0 * t).
|
||||||
|
matrix = tf.constant([[3.0, 4.0], [-4.0, 3.0]], dtype=tf.float64)
|
||||||
|
func = lambda y, t: tf.matmul(matrix, y)
|
||||||
|
|
||||||
|
y0 = tf.constant([[0.0], [1.0]], dtype=tf.float64)
|
||||||
|
t = np.linspace(0.0, 1.0, 11)
|
||||||
|
|
||||||
|
y_solved = tf.contrib.integrate.odeint(func, y0, t)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_solved = sess.run(y_solved)
|
||||||
|
|
||||||
|
y_true = np.zeros((len(t), 2, 1))
|
||||||
|
y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t)
|
||||||
|
y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t)
|
||||||
|
self.assertAllClose(y_true, y_solved, atol=1e-5)
|
||||||
|
|
||||||
|
def test_odeint_higher_rank(self):
|
||||||
|
func = lambda y, t: y
|
||||||
|
y0 = tf.constant(1.0, dtype=tf.float64)
|
||||||
|
t = np.linspace(0.0, 1.0, 11)
|
||||||
|
for shape in [(), (1,), (1, 1)]:
|
||||||
|
expected_shape = (len(t),) + shape
|
||||||
|
y_solved = tf.contrib.integrate.odeint(func, tf.reshape(y0, shape), t)
|
||||||
|
self.assertEqual(y_solved.get_shape(), tf.TensorShape(expected_shape))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_solved = sess.run(y_solved)
|
||||||
|
self.assertEquals(y_solved.shape, expected_shape)
|
||||||
|
|
||||||
|
def test_odeint_all_dtypes(self):
|
||||||
|
func = lambda y, t: y
|
||||||
|
t = np.linspace(0.0, 1.0, 11)
|
||||||
|
for y0_dtype in [tf.float32, tf.float64, tf.complex64, tf.complex128]:
|
||||||
|
for t_dtype in [tf.float32, tf.float64]:
|
||||||
|
y0 = tf.cast(1.0, y0_dtype)
|
||||||
|
y_solved = tf.contrib.integrate.odeint(func, y0, tf.cast(t, t_dtype))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_solved = sess.run(y_solved)
|
||||||
|
expected = np.asarray(np.exp(t))
|
||||||
|
self.assertAllClose(y_solved, expected, rtol=1e-5)
|
||||||
|
self.assertEqual(tf.as_dtype(y_solved.dtype), y0_dtype)
|
||||||
|
|
||||||
|
def test_odeint_required_dtypes(self):
|
||||||
|
with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'):
|
||||||
|
tf.contrib.integrate.odeint(self.func, tf.cast(self.y0, tf.int32), [0, 1])
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'):
|
||||||
|
tf.contrib.integrate.odeint(self.func, self.y0, tf.cast([0, 1], tf.int32))
|
||||||
|
|
||||||
|
def test_odeint_runtime_errors(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'cannot supply `options` without'):
|
||||||
|
tf.contrib.integrate.odeint(self.func, self.y0, [0, 1],
|
||||||
|
options={'first_step': 1.0})
|
||||||
|
|
||||||
|
y = tf.contrib.integrate.odeint(self.func, self.y0, [0, 1], method='dopri5',
|
||||||
|
options={'max_num_steps': 0})
|
||||||
|
with self.test_session() as sess:
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
tf.errors.InvalidArgumentError, 'max_num_steps'):
|
||||||
|
sess.run(y)
|
||||||
|
|
||||||
|
y = tf.contrib.integrate.odeint(self.func, self.y0, [1, 0])
|
||||||
|
with self.test_session() as sess:
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
tf.errors.InvalidArgumentError, 'monotonic increasing'):
|
||||||
|
sess.run(y)
|
||||||
|
|
||||||
|
def test_odeint_different_times(self):
|
||||||
|
# integration steps should be independent of interpolation times
|
||||||
|
times0 = np.linspace(0, 10, num=11, dtype=float)
|
||||||
|
times1 = np.linspace(0, 10, num=101, dtype=float)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_solved_0, info_0 = sess.run(
|
||||||
|
tf.contrib.integrate.odeint(
|
||||||
|
self.func, self.y0, times0, full_output=True))
|
||||||
|
y_solved_1, info_1 = sess.run(
|
||||||
|
tf.contrib.integrate.odeint(
|
||||||
|
self.func, self.y0, times1, full_output=True))
|
||||||
|
|
||||||
|
self.assertAllClose(y_solved_0, y_solved_1[::10])
|
||||||
|
self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals'])
|
||||||
|
self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points'])
|
||||||
|
self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio'])
|
||||||
|
|
||||||
|
def test_odeint_5th_order_accuracy(self):
|
||||||
|
t = [0, 20]
|
||||||
|
kwargs = dict(full_output=True,
|
||||||
|
method='dopri5',
|
||||||
|
options=dict(max_num_steps=2000))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
_, info_0 = sess.run(tf.contrib.integrate.odeint(
|
||||||
|
self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
|
||||||
|
_, info_1 = sess.run(tf.contrib.integrate.odeint(
|
||||||
|
self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs))
|
||||||
|
self.assertAllClose(info_0['integrate_points'].size * 1000 ** 0.2,
|
||||||
|
float(info_1['integrate_points'].size),
|
||||||
|
rtol=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
class StepSizeTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_error_ratio_one(self):
|
||||||
|
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||||
|
error_ratio=tf.constant(1.0))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
new_step = sess.run(new_step)
|
||||||
|
self.assertAllClose(new_step, 0.9)
|
||||||
|
|
||||||
|
def test_ifactor(self):
|
||||||
|
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||||
|
error_ratio=tf.constant(0.0))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
new_step = sess.run(new_step)
|
||||||
|
self.assertAllClose(new_step, 10.0)
|
||||||
|
|
||||||
|
def test_dfactor(self):
|
||||||
|
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||||
|
error_ratio=tf.constant(1e6))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
new_step = sess.run(new_step)
|
||||||
|
self.assertAllClose(new_step, 0.2)
|
||||||
|
|
||||||
|
|
||||||
|
class InterpolationTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_5th_order_polynomial(self):
|
||||||
|
# this should be an exact fit
|
||||||
|
f = lambda x: x ** 4 + x ** 3 - 2 * x ** 2 + 4 * x + 5
|
||||||
|
f_prime = lambda x: 4 * x ** 3 + 3 * x ** 2 - 4 * x + 4
|
||||||
|
coeffs = odes._interp_fit(
|
||||||
|
f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0)
|
||||||
|
times = np.linspace(0, 10, dtype=np.float32)
|
||||||
|
y_fit = tf.pack([odes._interp_evaluate(coeffs, 0.0, 10.0, t)
|
||||||
|
for t in times])
|
||||||
|
y_expected = f(times)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_actual = sess.run(y_fit)
|
||||||
|
self.assertAllClose(y_expected, y_actual)
|
||||||
|
|
||||||
|
# attempt interpolation outside bounds
|
||||||
|
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
|
sess.run(y_invalid)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
@ -258,10 +258,11 @@ def optimize_loss(loss,
|
|||||||
grad_values = gradient
|
grad_values = gradient
|
||||||
|
|
||||||
if grad_values is not None:
|
if grad_values is not None:
|
||||||
|
var_name = variable.name.replace(":", "_")
|
||||||
if "gradients" in summaries:
|
if "gradients" in summaries:
|
||||||
summary.histogram("gradients/" + variable.name, grad_values)
|
summary.histogram("gradients/%s" % var_name, grad_values)
|
||||||
if "gradient_norm" in summaries:
|
if "gradient_norm" in summaries:
|
||||||
summary.scalar("gradient_norm/" + variable.name,
|
summary.scalar("gradient_norm/%s" % var_name,
|
||||||
clip_ops.global_norm([grad_values]))
|
clip_ops.global_norm([grad_values]))
|
||||||
|
|
||||||
if clip_gradients is not None and "gradient_norm" in summaries:
|
if clip_gradients is not None and "gradient_norm" in summaries:
|
||||||
|
@ -291,7 +291,9 @@ py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":learn",
|
":learn",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:extra_py_tests_deps",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:test_ops",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -29,11 +29,11 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
|
|||||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
|
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
|
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
|
from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.head import MetricKey
|
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.head import PredictionKey
|
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
|
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
|
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
|
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
|
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import trainable
|
|||||||
from tensorflow.contrib.learn.python.learn.estimators import composable_model
|
from tensorflow.contrib.learn.python.learn.estimators import composable_model
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||||
from tensorflow.contrib.learn.python.learn.utils import export
|
from tensorflow.contrib.learn.python.learn.utils import export
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -747,13 +748,16 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
Numpy array of predicted classes (or an iterable of predicted classes if
|
Numpy array of predicted classes (or an iterable of predicted classes if
|
||||||
as_iterable is True).
|
as_iterable is True).
|
||||||
"""
|
"""
|
||||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
key = prediction_key.PredictionKey.CLASSES
|
||||||
batch_size=batch_size,
|
preds = self._estimator.predict(
|
||||||
outputs=[head_lib.PredictionKey.CLASSES],
|
x=x,
|
||||||
as_iterable=as_iterable)
|
input_fn=input_fn,
|
||||||
|
batch_size=batch_size,
|
||||||
|
outputs=[key],
|
||||||
|
as_iterable=as_iterable)
|
||||||
if as_iterable:
|
if as_iterable:
|
||||||
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
|
return _as_iterable(preds, output=key)
|
||||||
return preds[head_lib.PredictionKey.CLASSES].reshape(-1)
|
return preds[key].reshape(-1)
|
||||||
|
|
||||||
@deprecated_arg_values(
|
@deprecated_arg_values(
|
||||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||||
@ -775,20 +779,22 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
Numpy array of predicted probabilities (or an iterable of predicted
|
Numpy array of predicted probabilities (or an iterable of predicted
|
||||||
probabilities if as_iterable is True).
|
probabilities if as_iterable is True).
|
||||||
"""
|
"""
|
||||||
|
key = prediction_key.PredictionKey.PROBABILITIES
|
||||||
preds = self._estimator.predict(
|
preds = self._estimator.predict(
|
||||||
x=x, input_fn=input_fn,
|
x=x,
|
||||||
|
input_fn=input_fn,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
outputs=[head_lib.PredictionKey.PROBABILITIES],
|
outputs=[key],
|
||||||
as_iterable=as_iterable)
|
as_iterable=as_iterable)
|
||||||
if as_iterable:
|
if as_iterable:
|
||||||
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
|
return _as_iterable(preds, output=key)
|
||||||
return preds[head_lib.PredictionKey.PROBABILITIES]
|
return preds[key]
|
||||||
|
|
||||||
def _get_predict_ops(self, features):
|
def _get_predict_ops(self, features):
|
||||||
"""See `Estimator` class."""
|
"""See `Estimator` class."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return self._estimator._get_predict_ops(features)[
|
return self._estimator._get_predict_ops(features)[
|
||||||
head_lib.PredictionKey.PROBABILITIES]
|
prediction_key.PredictionKey.PROBABILITIES]
|
||||||
|
|
||||||
def get_variable_names(self):
|
def get_variable_names(self):
|
||||||
"""Returns list of all variable names in this model.
|
"""Returns list of all variable names in this model.
|
||||||
@ -826,9 +832,9 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
input_fn=input_fn or default_input_fn,
|
input_fn=input_fn or default_input_fn,
|
||||||
input_feature_key=input_feature_key,
|
input_feature_key=input_feature_key,
|
||||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||||
signature_fn=(
|
signature_fn=(signature_fn or
|
||||||
signature_fn or export.classification_signature_fn_with_prob),
|
export.classification_signature_fn_with_prob),
|
||||||
prediction_key=head_lib.PredictionKey.PROBABILITIES,
|
prediction_key=prediction_key.PredictionKey.PROBABILITIES,
|
||||||
default_batch_size=default_batch_size,
|
default_batch_size=default_batch_size,
|
||||||
exports_to_keep=exports_to_keep)
|
exports_to_keep=exports_to_keep)
|
||||||
|
|
||||||
@ -1041,10 +1047,11 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
|
|||||||
head=head,
|
head=head,
|
||||||
config=config,
|
config=config,
|
||||||
feature_engineering_fn=feature_engineering_fn,
|
feature_engineering_fn=feature_engineering_fn,
|
||||||
default_prediction_key=head_lib.PredictionKey.SCORES,
|
default_prediction_key=prediction_key.PredictionKey.SCORES,
|
||||||
enable_centered_bias=enable_centered_bias)
|
enable_centered_bias=enable_centered_bias)
|
||||||
|
|
||||||
def _get_predict_ops(self, features):
|
def _get_predict_ops(self, features):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return super(DNNLinearCombinedRegressor, self)._get_predict_ops(features)[
|
return super(
|
||||||
head_lib.PredictionKey.SCORES]
|
DNNLinearCombinedRegressor,
|
||||||
|
self)._get_predict_ops(features)[prediction_key.PredictionKey.SCORES]
|
||||||
|
@ -45,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import metric_spec
|
|||||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||||
from tensorflow.contrib.learn.python.learn import trainable
|
from tensorflow.contrib.learn.python.learn import trainable
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
|
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import run_config
|
from tensorflow.contrib.learn.python.learn.estimators import run_config
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
|
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
|
||||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
|
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
|
||||||
@ -1108,8 +1109,9 @@ class Estimator(BaseEstimator):
|
|||||||
|
|
||||||
result = _make_metrics_ops(all_metrics, features, labels,
|
result = _make_metrics_ops(all_metrics, features, labels,
|
||||||
model_fn_ops.predictions)
|
model_fn_ops.predictions)
|
||||||
if 'loss' not in result:
|
if metric_key.MetricKey.LOSS not in result:
|
||||||
result['loss'] = metrics_lib.streaming_mean(model_fn_ops.loss)
|
result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean(
|
||||||
|
model_fn_ops.loss)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_predict_ops(self, features):
|
def _get_predict_ops(self, features):
|
||||||
|
@ -24,6 +24,8 @@ from tensorflow.contrib import losses
|
|||||||
from tensorflow.contrib import metrics as metrics_lib
|
from tensorflow.contrib import metrics as metrics_lib
|
||||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||||
from tensorflow.contrib.session_bundle import exporter
|
from tensorflow.contrib.session_bundle import exporter
|
||||||
from tensorflow.python import summary
|
from tensorflow.python import summary
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -388,17 +390,17 @@ class _RegressionHead(_Head):
|
|||||||
def _logits_to_prediction(self, logits=None):
|
def _logits_to_prediction(self, logits=None):
|
||||||
predictions = {}
|
predictions = {}
|
||||||
if self.logits_dimension == 1:
|
if self.logits_dimension == 1:
|
||||||
predictions[PredictionKey.SCORES] = array_ops.squeeze(
|
predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
|
||||||
logits, squeeze_dims=[1])
|
logits, squeeze_dims=[1])
|
||||||
else:
|
else:
|
||||||
predictions[PredictionKey.SCORES] = logits
|
predictions[prediction_key.PredictionKey.SCORES] = logits
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
# pylint: disable=undefined-variable
|
# pylint: disable=undefined-variable
|
||||||
def _create_signature_fn(self):
|
def _create_signature_fn(self):
|
||||||
def _regression_signature_fn(examples, unused_features, predictions):
|
def _regression_signature_fn(examples, unused_features, predictions):
|
||||||
if isinstance(predictions, dict):
|
if isinstance(predictions, dict):
|
||||||
score = predictions[PredictionKey.SCORES]
|
score = predictions[prediction_key.PredictionKey.SCORES]
|
||||||
else:
|
else:
|
||||||
score = predictions
|
score = predictions
|
||||||
|
|
||||||
@ -409,11 +411,12 @@ class _RegressionHead(_Head):
|
|||||||
return _regression_signature_fn
|
return _regression_signature_fn
|
||||||
|
|
||||||
def _default_metric(self):
|
def _default_metric(self):
|
||||||
return {_head_prefixed(self._head_name, MetricKey.LOSS):
|
return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
_weighted_average_loss_metric_spec(
|
||||||
PredictionKey.SCORES,
|
self._eval_loss_fn,
|
||||||
self._label_name,
|
prediction_key.PredictionKey.SCORES,
|
||||||
self._weight_column_name)}
|
self._label_name,
|
||||||
|
self._weight_column_name)}
|
||||||
|
|
||||||
|
|
||||||
class _MultiClassHead(_Head):
|
class _MultiClassHead(_Head):
|
||||||
@ -530,12 +533,16 @@ class _MultiClassHead(_Head):
|
|||||||
return self._logits_to_prediction(logits)
|
return self._logits_to_prediction(logits)
|
||||||
|
|
||||||
def _logits_to_prediction(self, logits=None):
|
def _logits_to_prediction(self, logits=None):
|
||||||
predictions = {PredictionKey.LOGITS: logits}
|
# pylint: disable=missing-docstring
|
||||||
|
predictions = {prediction_key.PredictionKey.LOGITS: logits}
|
||||||
if self.logits_dimension == 1:
|
if self.logits_dimension == 1:
|
||||||
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
|
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
|
||||||
|
logits)
|
||||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||||
predictions[PredictionKey.PROBABILITIES] = nn.softmax(logits)
|
predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
|
||||||
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
|
logits)
|
||||||
|
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
|
||||||
|
logits, 1)
|
||||||
|
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
@ -546,8 +553,9 @@ class _MultiClassHead(_Head):
|
|||||||
if isinstance(predictions, dict):
|
if isinstance(predictions, dict):
|
||||||
default_signature = exporter.classification_signature(
|
default_signature = exporter.classification_signature(
|
||||||
input_tensor=examples,
|
input_tensor=examples,
|
||||||
classes_tensor=predictions[PredictionKey.CLASSES],
|
classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
|
||||||
scores_tensor=predictions[PredictionKey.PROBABILITIES])
|
scores_tensor=predictions[
|
||||||
|
prediction_key.PredictionKey.PROBABILITIES])
|
||||||
else:
|
else:
|
||||||
default_signature = exporter.classification_signature(
|
default_signature = exporter.classification_signature(
|
||||||
input_tensor=examples,
|
input_tensor=examples,
|
||||||
@ -558,44 +566,49 @@ class _MultiClassHead(_Head):
|
|||||||
return _classification_signature_fn
|
return _classification_signature_fn
|
||||||
|
|
||||||
def _default_metric(self):
|
def _default_metric(self):
|
||||||
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
|
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
_weighted_average_loss_metric_spec(
|
||||||
PredictionKey.LOGITS,
|
self._eval_loss_fn,
|
||||||
self._label_name,
|
prediction_key.PredictionKey.LOGITS,
|
||||||
self._weight_column_name)}
|
self._label_name,
|
||||||
|
self._weight_column_name)}
|
||||||
|
|
||||||
# TODO(b/29366811): This currently results in both an "accuracy" and an
|
# TODO(b/29366811): This currently results in both an "accuracy" and an
|
||||||
# "accuracy/threshold_0.500000_mean" metric for binary classification.
|
# "accuracy/threshold_0.500000_mean" metric for binary classification.
|
||||||
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
|
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
|
||||||
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
||||||
PredictionKey.CLASSES, self._label_name,
|
prediction_key.PredictionKey.CLASSES,
|
||||||
|
self._label_name,
|
||||||
self._weight_column_name))
|
self._weight_column_name))
|
||||||
if self.logits_dimension == 1:
|
if self.logits_dimension == 1:
|
||||||
def _add_binary_metric(metric_key, metric_fn):
|
def _add_binary_metric(key, metric_fn):
|
||||||
metrics[_head_prefixed(self._head_name, metric_key)] = (
|
metrics[_head_prefixed(self._head_name, key)] = (
|
||||||
metric_spec.MetricSpec(metric_fn,
|
metric_spec.MetricSpec(metric_fn,
|
||||||
PredictionKey.LOGISTIC,
|
prediction_key.PredictionKey.LOGISTIC,
|
||||||
self._label_name,
|
self._label_name,
|
||||||
self._weight_column_name))
|
self._weight_column_name))
|
||||||
_add_binary_metric(MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
|
_add_binary_metric(
|
||||||
_add_binary_metric(MetricKey.LABEL_MEAN, _labels_streaming_mean)
|
metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
|
||||||
|
_add_binary_metric(
|
||||||
|
metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
|
||||||
|
|
||||||
# Also include the streaming mean of the label as an accuracy baseline, as
|
# Also include the streaming mean of the label as an accuracy baseline, as
|
||||||
# a reminder to users.
|
# a reminder to users.
|
||||||
_add_binary_metric(MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
|
_add_binary_metric(
|
||||||
|
metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
|
||||||
|
|
||||||
_add_binary_metric(MetricKey.AUC, _streaming_auc)
|
_add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
|
||||||
|
|
||||||
for threshold in self._thresholds:
|
for threshold in self._thresholds:
|
||||||
_add_binary_metric(MetricKey.ACCURACY_MEAN % threshold,
|
_add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
|
||||||
_accuracy_at_threshold(threshold))
|
_accuracy_at_threshold(threshold))
|
||||||
# Precision for positive examples.
|
# Precision for positive examples.
|
||||||
_add_binary_metric(MetricKey.PRECISION_MEAN % threshold,
|
_add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
|
||||||
_streaming_at_threshold(
|
_streaming_at_threshold(
|
||||||
metrics_lib.streaming_precision_at_thresholds,
|
metrics_lib.streaming_precision_at_thresholds,
|
||||||
threshold),)
|
threshold),)
|
||||||
# Recall for positive examples.
|
# Recall for positive examples.
|
||||||
_add_binary_metric(MetricKey.RECALL_MEAN % threshold,
|
_add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
|
||||||
_streaming_at_threshold(
|
_streaming_at_threshold(
|
||||||
metrics_lib.streaming_recall_at_thresholds,
|
metrics_lib.streaming_recall_at_thresholds,
|
||||||
threshold))
|
threshold))
|
||||||
@ -635,21 +648,24 @@ class _BinarySvmHead(_MultiClassHead):
|
|||||||
|
|
||||||
def _logits_to_prediction(self, logits=None):
|
def _logits_to_prediction(self, logits=None):
|
||||||
predictions = {}
|
predictions = {}
|
||||||
predictions[PredictionKey.LOGITS] = logits
|
predictions[prediction_key.PredictionKey.LOGITS] = logits
|
||||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||||
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
|
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
|
||||||
|
logits, 1)
|
||||||
|
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
def _default_metric(self):
|
def _default_metric(self):
|
||||||
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
|
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
_weighted_average_loss_metric_spec(
|
||||||
PredictionKey.LOGITS,
|
self._eval_loss_fn,
|
||||||
self._label_name,
|
prediction_key.PredictionKey.LOGITS,
|
||||||
self._weight_column_name)}
|
self._label_name,
|
||||||
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
|
self._weight_column_name)}
|
||||||
|
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
|
||||||
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
||||||
PredictionKey.CLASSES, self._label_name,
|
prediction_key.PredictionKey.CLASSES,
|
||||||
|
self._label_name,
|
||||||
self._weight_column_name))
|
self._weight_column_name))
|
||||||
# TODO(sibyl-vie3Poto): add more metrics relevant for svms.
|
# TODO(sibyl-vie3Poto): add more metrics relevant for svms.
|
||||||
return metrics
|
return metrics
|
||||||
@ -674,12 +690,14 @@ class _MultiLabelHead(_MultiClassHead):
|
|||||||
thresholds=thresholds)
|
thresholds=thresholds)
|
||||||
|
|
||||||
def _logits_to_prediction(self, logits=None):
|
def _logits_to_prediction(self, logits=None):
|
||||||
predictions = {PredictionKey.LOGITS: logits}
|
predictions = {prediction_key.PredictionKey.LOGITS: logits}
|
||||||
if self.logits_dimension == 1:
|
if self.logits_dimension == 1:
|
||||||
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
|
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
|
||||||
|
logits)
|
||||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||||
predictions[PredictionKey.PROBABILITIES] = math_ops.sigmoid(logits)
|
predictions[prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid(
|
||||||
predictions[PredictionKey.CLASSES] = math_ops.to_int64(
|
logits)
|
||||||
|
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64(
|
||||||
math_ops.greater(logits, 0))
|
math_ops.greater(logits, 0))
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
@ -849,23 +867,3 @@ def _streaming_at_threshold(streaming_metrics_fn, threshold):
|
|||||||
return array_ops.squeeze(precision_tensor), update_op
|
return array_ops.squeeze(precision_tensor), update_op
|
||||||
|
|
||||||
return _streaming_metrics
|
return _streaming_metrics
|
||||||
|
|
||||||
|
|
||||||
class PredictionKey(object):
|
|
||||||
CLASSES = "classes"
|
|
||||||
PROBABILITIES = "probabilities"
|
|
||||||
LOGITS = "logits"
|
|
||||||
LOGISTIC = "logistic"
|
|
||||||
SCORES = "scores"
|
|
||||||
|
|
||||||
|
|
||||||
class MetricKey(object):
|
|
||||||
LOSS = "loss"
|
|
||||||
AUC = "auc"
|
|
||||||
PREDICTION_MEAN = "labels/prediction_mean"
|
|
||||||
LABEL_MEAN = "labels/actual_label_mean"
|
|
||||||
ACCURACY = "accuracy"
|
|
||||||
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
|
|
||||||
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
|
|
||||||
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
|
|
||||||
RECALL_MEAN = "recall/positive_threshold_%f_mean"
|
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.contrib.learn.python.learn import evaluable
|
|||||||
from tensorflow.contrib.learn.python.learn import trainable
|
from tensorflow.contrib.learn.python.learn import trainable
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||||
from tensorflow.contrib.learn.python.learn.utils import export
|
from tensorflow.contrib.learn.python.learn.utils import export
|
||||||
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -267,21 +268,18 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
education = sparse_column_with_hash_bucket(column_name="education",
|
sparse_column_a = sparse_column_with_hash_bucket(...)
|
||||||
hash_bucket_size=1000)
|
sparse_column_b = sparse_column_with_hash_bucket(...)
|
||||||
occupation = sparse_column_with_hash_bucket(column_name="occupation",
|
|
||||||
hash_bucket_size=1000)
|
|
||||||
|
|
||||||
education_x_occupation = crossed_column(columns=[education, occupation],
|
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
|
||||||
hash_bucket_size=10000)
|
|
||||||
|
|
||||||
# Estimator using the default optimizer.
|
# Estimator using the default optimizer.
|
||||||
estimator = LinearClassifier(
|
estimator = LinearClassifier(
|
||||||
feature_columns=[occupation, education_x_occupation])
|
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
|
||||||
|
|
||||||
# Or estimator using the FTRL optimizer with regularization.
|
# Or estimator using the FTRL optimizer with regularization.
|
||||||
estimator = LinearClassifier(
|
estimator = LinearClassifier(
|
||||||
feature_columns=[occupation, education_x_occupation],
|
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
|
||||||
optimizer=tf.train.FtrlOptimizer(
|
optimizer=tf.train.FtrlOptimizer(
|
||||||
learning_rate=0.1,
|
learning_rate=0.1,
|
||||||
l1_regularization_strength=0.001
|
l1_regularization_strength=0.001
|
||||||
@ -289,7 +287,7 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
|
|
||||||
# Or estimator using the SDCAOptimizer.
|
# Or estimator using the SDCAOptimizer.
|
||||||
estimator = LinearClassifier(
|
estimator = LinearClassifier(
|
||||||
feature_columns=[occupation, education_x_occupation],
|
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
|
||||||
optimizer=tf.contrib.linear_optimizer.SDCAOptimizer(
|
optimizer=tf.contrib.linear_optimizer.SDCAOptimizer(
|
||||||
example_id_column='example_id',
|
example_id_column='example_id',
|
||||||
num_loss_partitions=...,
|
num_loss_partitions=...,
|
||||||
@ -465,13 +463,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
as_iterable=False)
|
as_iterable=False)
|
||||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||||
"""Runs inference to determine the predicted class."""
|
"""Runs inference to determine the predicted class."""
|
||||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
key = prediction_key.PredictionKey.CLASSES
|
||||||
batch_size=batch_size,
|
preds = self._estimator.predict(
|
||||||
outputs=[head_lib.PredictionKey.CLASSES],
|
x=x,
|
||||||
as_iterable=as_iterable)
|
input_fn=input_fn,
|
||||||
|
batch_size=batch_size,
|
||||||
|
outputs=[key],
|
||||||
|
as_iterable=as_iterable)
|
||||||
if as_iterable:
|
if as_iterable:
|
||||||
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
|
return _as_iterable(preds, output=key)
|
||||||
return preds[head_lib.PredictionKey.CLASSES]
|
return preds[key]
|
||||||
|
|
||||||
@deprecated_arg_values(
|
@deprecated_arg_values(
|
||||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||||
@ -479,14 +480,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
|
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
|
||||||
as_iterable=True):
|
as_iterable=True):
|
||||||
"""Runs inference to determine the class probability predictions."""
|
"""Runs inference to determine the class probability predictions."""
|
||||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
key = prediction_key.PredictionKey.PROBABILITIES
|
||||||
batch_size=batch_size,
|
preds = self._estimator.predict(
|
||||||
outputs=[
|
x=x,
|
||||||
head_lib.PredictionKey.PROBABILITIES],
|
input_fn=input_fn,
|
||||||
as_iterable=as_iterable)
|
batch_size=batch_size,
|
||||||
|
outputs=[key],
|
||||||
|
as_iterable=as_iterable)
|
||||||
if as_iterable:
|
if as_iterable:
|
||||||
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
|
return _as_iterable(preds, output=key)
|
||||||
return preds[head_lib.PredictionKey.PROBABILITIES]
|
return preds[key]
|
||||||
|
|
||||||
def get_variable_names(self):
|
def get_variable_names(self):
|
||||||
return self._estimator.get_variable_names()
|
return self._estimator.get_variable_names()
|
||||||
@ -512,9 +515,9 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
input_fn=input_fn or default_input_fn,
|
input_fn=input_fn or default_input_fn,
|
||||||
input_feature_key=input_feature_key,
|
input_feature_key=input_feature_key,
|
||||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||||
signature_fn=(
|
signature_fn=(signature_fn or
|
||||||
signature_fn or export.classification_signature_fn_with_prob),
|
export.classification_signature_fn_with_prob),
|
||||||
prediction_key=head_lib.PredictionKey.PROBABILITIES,
|
prediction_key=prediction_key.PredictionKey.PROBABILITIES,
|
||||||
default_batch_size=default_batch_size,
|
default_batch_size=default_batch_size,
|
||||||
exports_to_keep=exports_to_keep)
|
exports_to_keep=exports_to_keep)
|
||||||
|
|
||||||
@ -561,16 +564,13 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
|||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
education = sparse_column_with_hash_bucket(column_name="education",
|
sparse_column_a = sparse_column_with_hash_bucket(...)
|
||||||
hash_bucket_size=1000)
|
sparse_column_b = sparse_column_with_hash_bucket(...)
|
||||||
occupation = sparse_column_with_hash_bucket(column_name="occupation",
|
|
||||||
hash_bucket_size=1000)
|
|
||||||
|
|
||||||
education_x_occupation = crossed_column(columns=[education, occupation],
|
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
|
||||||
hash_bucket_size=10000)
|
|
||||||
|
|
||||||
estimator = LinearRegressor(
|
estimator = LinearRegressor(
|
||||||
feature_columns=[occupation, education_x_occupation])
|
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
|
||||||
|
|
||||||
# Input builders
|
# Input builders
|
||||||
def input_fn_train: # returns x, y
|
def input_fn_train: # returns x, y
|
||||||
@ -731,13 +731,16 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
|||||||
as_iterable=False)
|
as_iterable=False)
|
||||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||||
"""Runs inference to determine the predicted class."""
|
"""Runs inference to determine the predicted class."""
|
||||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
key = prediction_key.PredictionKey.SCORES
|
||||||
batch_size=batch_size,
|
preds = self._estimator.predict(
|
||||||
outputs=[head_lib.PredictionKey.SCORES],
|
x=x,
|
||||||
as_iterable=as_iterable)
|
input_fn=input_fn,
|
||||||
|
batch_size=batch_size,
|
||||||
|
outputs=[key],
|
||||||
|
as_iterable=as_iterable)
|
||||||
if as_iterable:
|
if as_iterable:
|
||||||
return _as_iterable(preds, output=head_lib.PredictionKey.SCORES)
|
return _as_iterable(preds, output=key)
|
||||||
return preds[head_lib.PredictionKey.SCORES]
|
return preds[key]
|
||||||
|
|
||||||
def get_variable_names(self):
|
def get_variable_names(self):
|
||||||
return self._estimator.get_variable_names()
|
return self._estimator.get_variable_names()
|
||||||
@ -764,7 +767,7 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
|||||||
input_feature_key=input_feature_key,
|
input_feature_key=input_feature_key,
|
||||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||||
signature_fn=(signature_fn or export.regression_signature_fn),
|
signature_fn=(signature_fn or export.regression_signature_fn),
|
||||||
prediction_key=head_lib.PredictionKey.SCORES,
|
prediction_key=prediction_key.PredictionKey.SCORES,
|
||||||
default_batch_size=default_batch_size,
|
default_batch_size=default_batch_size,
|
||||||
exports_to_keep=exports_to_keep)
|
exports_to_keep=exports_to_keep)
|
||||||
|
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Enum for metric keys."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
|
||||||
|
class MetricKey(object):
|
||||||
|
LOSS = "loss"
|
||||||
|
AUC = "auc"
|
||||||
|
PREDICTION_MEAN = "labels/prediction_mean"
|
||||||
|
LABEL_MEAN = "labels/actual_label_mean"
|
||||||
|
ACCURACY = "accuracy"
|
||||||
|
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
|
||||||
|
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
|
||||||
|
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
|
||||||
|
RECALL_MEAN = "recall/positive_threshold_%f_mean"
|
@ -0,0 +1,26 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Enum for model prediction keys."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionKey(object):
|
||||||
|
CLASSES = "classes"
|
||||||
|
PROBABILITIES = "probabilities"
|
||||||
|
LOGITS = "logits"
|
||||||
|
LOGISTIC = "logistic"
|
||||||
|
SCORES = "scores"
|
@ -30,6 +30,7 @@ from tensorflow.contrib.learn.python.learn import trainable
|
|||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import linear
|
from tensorflow.contrib.learn.python.learn.estimators import linear
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||||
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
||||||
|
|
||||||
|
|
||||||
@ -188,13 +189,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
|||||||
as_iterable=False)
|
as_iterable=False)
|
||||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||||
"""Runs inference to determine the predicted class."""
|
"""Runs inference to determine the predicted class."""
|
||||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
key = prediction_key.PredictionKey.CLASSES
|
||||||
batch_size=batch_size,
|
preds = self._estimator.predict(
|
||||||
outputs=[head_lib.PredictionKey.CLASSES],
|
x=x,
|
||||||
as_iterable=as_iterable)
|
input_fn=input_fn,
|
||||||
|
batch_size=batch_size,
|
||||||
|
outputs=[key],
|
||||||
|
as_iterable=as_iterable)
|
||||||
if as_iterable:
|
if as_iterable:
|
||||||
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
|
return _as_iterable(preds, output=key)
|
||||||
return preds[head_lib.PredictionKey.CLASSES]
|
return preds[key]
|
||||||
|
|
||||||
@deprecated_arg_values(
|
@deprecated_arg_values(
|
||||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||||
@ -202,14 +206,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
|||||||
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
|
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
|
||||||
as_iterable=True):
|
as_iterable=True):
|
||||||
"""Runs inference to determine the class probability predictions."""
|
"""Runs inference to determine the class probability predictions."""
|
||||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
key = prediction_key.PredictionKey.PROBABILITIES
|
||||||
batch_size=batch_size,
|
preds = self._estimator.predict(
|
||||||
outputs=[
|
x=x,
|
||||||
head_lib.PredictionKey.PROBABILITIES],
|
input_fn=input_fn,
|
||||||
as_iterable=as_iterable)
|
batch_size=batch_size,
|
||||||
|
outputs=[key],
|
||||||
|
as_iterable=as_iterable)
|
||||||
if as_iterable:
|
if as_iterable:
|
||||||
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
|
return _as_iterable(preds, output=key)
|
||||||
return preds[head_lib.PredictionKey.PROBABILITIES]
|
return preds[key]
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
def get_variable_names(self):
|
def get_variable_names(self):
|
||||||
|
@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import data_flow_ops
|
||||||
from tensorflow.python.ops import logging_ops
|
from tensorflow.python.ops import logging_ops
|
||||||
|
from tensorflow.python.ops import resources
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import basic_session_run_hooks
|
from tensorflow.python.training import basic_session_run_hooks
|
||||||
@ -77,7 +78,8 @@ def get_summary_writer(logdir):
|
|||||||
|
|
||||||
|
|
||||||
def _make_saver(graph, keep_checkpoint_max=5):
|
def _make_saver(graph, keep_checkpoint_max=5):
|
||||||
vars_to_save = graph.get_collection(ops.GraphKeys.VARIABLES)
|
vars_to_save = (graph.get_collection(ops.GraphKeys.VARIABLES) +
|
||||||
|
graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
|
||||||
if vars_to_save:
|
if vars_to_save:
|
||||||
return tf_saver.Saver(vars_to_save,
|
return tf_saver.Saver(vars_to_save,
|
||||||
sharded=True,
|
sharded=True,
|
||||||
@ -846,9 +848,11 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
|
|||||||
raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)
|
raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)
|
||||||
|
|
||||||
graph = contrib_ops.get_graph_from_inputs(output_dict.values())
|
graph = contrib_ops.get_graph_from_inputs(output_dict.values())
|
||||||
|
|
||||||
with graph.as_default() as g:
|
with graph.as_default() as g:
|
||||||
with tf_session.Session('') as session:
|
with tf_session.Session('') as session:
|
||||||
|
session.run(
|
||||||
|
resources.initialize_resources(resources.shared_resources() +
|
||||||
|
resources.local_resources()))
|
||||||
if restore_checkpoint_path:
|
if restore_checkpoint_path:
|
||||||
_restore_from_checkpoint(session, g, restore_checkpoint_path)
|
_restore_from_checkpoint(session, g, restore_checkpoint_path)
|
||||||
else:
|
else:
|
||||||
|
@ -28,6 +28,8 @@ from tensorflow.contrib.learn.python import learn
|
|||||||
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
||||||
from tensorflow.python.framework import meta_graph
|
from tensorflow.python.framework import meta_graph
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_ops
|
||||||
|
from tensorflow.python.ops import resources
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
|
||||||
|
|
||||||
@ -194,6 +196,19 @@ class GraphActionsTest(tf.test.TestCase):
|
|||||||
pass
|
pass
|
||||||
self.assertTrue(request_stop.called)
|
self.assertTrue(request_stop.called)
|
||||||
|
|
||||||
|
def test_run_feeds_iter_calls_resources_init(self):
|
||||||
|
with tf.Graph().as_default() as g:
|
||||||
|
in0, _, _ = self._build_inference_graph()
|
||||||
|
handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
|
||||||
|
resources.register_resource(
|
||||||
|
handle=handle,
|
||||||
|
create_op=test_ops.resource_create_op(handle),
|
||||||
|
is_initialized_op=test_ops.resource_initialized_op(handle))
|
||||||
|
|
||||||
|
for _ in learn.graph_actions.run_feeds_iter({'in0': in0},
|
||||||
|
feed_dicts=[{}]):
|
||||||
|
self.assertTrue(test_ops.resource_initialized_op(handle).eval())
|
||||||
|
|
||||||
def test_infer_different_default_graph(self):
|
def test_infer_different_default_graph(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
self._assert_ckpt(self._output_dir, False)
|
self._assert_ckpt(self._output_dir, False)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
### TensorFlow Makefile
|
### TensorFlow Makefile
|
||||||
|
|
||||||
The recommended way to build TensorFlow from source is using the Bazel
|
The recommended way to build TensorFlow from source is using the Bazel
|
||||||
open-source build system. Sometimes this isn't possible.
|
open-source build system. Sometimes this isn't possible. For example,
|
||||||
|
if you are building for iOS, you currently need to use the Makefile.
|
||||||
|
|
||||||
- The build system may not have the RAM or processing power to support Bazel.
|
- The build system may not have the RAM or processing power to support Bazel.
|
||||||
- Bazel or its dependencies may not be available.
|
- Bazel or its dependencies may not be available.
|
||||||
|
@ -43,6 +43,13 @@ tensorflow/core/kernels/sequence_ops.cc
|
|||||||
tensorflow/core/kernels/sendrecv_ops.cc
|
tensorflow/core/kernels/sendrecv_ops.cc
|
||||||
tensorflow/core/kernels/scatter_op.cc
|
tensorflow/core/kernels/scatter_op.cc
|
||||||
tensorflow/core/kernels/scatter_functor.cc
|
tensorflow/core/kernels/scatter_functor.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op.cc
|
||||||
tensorflow/core/kernels/save_restore_tensor.cc
|
tensorflow/core/kernels/save_restore_tensor.cc
|
||||||
tensorflow/core/kernels/save_restore_v2_ops.cc
|
tensorflow/core/kernels/save_restore_v2_ops.cc
|
||||||
tensorflow/core/kernels/save_op.cc
|
tensorflow/core/kernels/save_op.cc
|
||||||
|
@ -763,7 +763,12 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
|
|||||||
computes the area under a discretized curve of precision versus recall values
|
computes the area under a discretized curve of precision versus recall values
|
||||||
(computed using the aforementioned variables). The `num_thresholds` variable
|
(computed using the aforementioned variables). The `num_thresholds` variable
|
||||||
controls the degree of discretization with larger numbers of thresholds more
|
controls the degree of discretization with larger numbers of thresholds more
|
||||||
closely approximating the true AUC.
|
closely approximating the true AUC. The quality of the approximation may vary
|
||||||
|
dramatically depending on `num_thresholds`.
|
||||||
|
|
||||||
|
For best results, `predictions` should be distributed approximately uniformly
|
||||||
|
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
|
||||||
|
approximation may be poor if this is not the case.
|
||||||
|
|
||||||
For estimation of the metric over a stream of data, the function creates an
|
For estimation of the metric over a stream of data, the function creates an
|
||||||
`update_op` operation that updates these variables and returns the `auc`.
|
`update_op` operation that updates these variables and returns the `auc`.
|
||||||
|
@ -15,9 +15,16 @@ py_library(
|
|||||||
"python/training/resample.py",
|
"python/training/resample.py",
|
||||||
"python/training/sampling_ops.py",
|
"python/training/sampling_ops.py",
|
||||||
"python/training/sequence_queueing_state_saver.py",
|
"python/training/sequence_queueing_state_saver.py",
|
||||||
|
"python/training/training.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:framework",
|
||||||
|
"//tensorflow/python:ops",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//tensorflow/python:training",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
@ -98,6 +105,19 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "training_test",
|
||||||
|
size = "large",
|
||||||
|
srcs = ["python/training/training_test.py"],
|
||||||
|
shard_count = 3,
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":training_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "all_files",
|
name = "all_files",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
|
@ -70,6 +70,11 @@ from tensorflow.contrib.training.python.training.bucket_ops import *
|
|||||||
from tensorflow.contrib.training.python.training.resample import *
|
from tensorflow.contrib.training.python.training.resample import *
|
||||||
from tensorflow.contrib.training.python.training.sampling_ops import *
|
from tensorflow.contrib.training.python.training.sampling_ops import *
|
||||||
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
||||||
|
from tensorflow.contrib.training.python.training.training import add_gradients_summaries
|
||||||
|
from tensorflow.contrib.training.python.training.training import clip_gradient_norms
|
||||||
|
from tensorflow.contrib.training.python.training.training import create_train_op
|
||||||
|
from tensorflow.contrib.training.python.training.training import multiply_gradients
|
||||||
|
from tensorflow.contrib.training.python.training.training import train
|
||||||
from tensorflow.python.util.all_util import make_all
|
from tensorflow.python.util.all_util import make_all
|
||||||
|
|
||||||
__all__ = make_all(__name__)
|
__all__ = make_all(__name__)
|
||||||
|
316
tensorflow/contrib/training/python/training/training.py
Normal file
316
tensorflow/contrib/training/python/training/training.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Contains various routines and helper functions for training models.
|
||||||
|
|
||||||
|
TODO(nsilberman): Port documentation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.framework.python.ops import variables
|
||||||
|
from tensorflow.python import summary
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import clip_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import variables as tf_variables
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.training import basic_session_run_hooks
|
||||||
|
from tensorflow.python.training import monitored_session
|
||||||
|
from tensorflow.python.training import optimizer as tf_optimizer
|
||||||
|
|
||||||
|
# TODO(nsilberman): move add_gradients_summaries, clip_gradient_norms and
|
||||||
|
# multiply_gradients into contrib/summaries and contrib/optimizers.py
|
||||||
|
__all__ = [
|
||||||
|
'add_gradients_summaries',
|
||||||
|
'clip_gradient_norms',
|
||||||
|
'create_train_op',
|
||||||
|
'multiply_gradients',
|
||||||
|
'train',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_gradients_summaries(grads_and_vars):
|
||||||
|
"""Add summaries to gradients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of created summaries.
|
||||||
|
"""
|
||||||
|
summaries = []
|
||||||
|
for grad, var in grads_and_vars:
|
||||||
|
if grad is not None:
|
||||||
|
if isinstance(grad, ops.IndexedSlices):
|
||||||
|
grad_values = grad.values
|
||||||
|
else:
|
||||||
|
grad_values = grad
|
||||||
|
summaries.append(summary.histogram_summary(
|
||||||
|
var.op.name + ':gradient', grad_values))
|
||||||
|
summaries.append(summary.histogram_summary(
|
||||||
|
var.op.name + ':gradient_norm', clip_ops.global_norm([grad_values])))
|
||||||
|
else:
|
||||||
|
logging.info('Var %s has no gradient', var.op.name)
|
||||||
|
|
||||||
|
return summaries
|
||||||
|
|
||||||
|
|
||||||
|
def clip_gradient_norms(gradients_to_variables, max_norm):
|
||||||
|
"""Clips the gradients by the given value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gradients_to_variables: A list of gradient to variable pairs (tuples).
|
||||||
|
max_norm: the maximum norm value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of clipped gradient to variable pairs.
|
||||||
|
"""
|
||||||
|
clipped_grads_and_vars = []
|
||||||
|
for grad, var in gradients_to_variables:
|
||||||
|
if grad is not None:
|
||||||
|
if isinstance(grad, ops.IndexedSlices):
|
||||||
|
tmp = clip_ops.clip_by_norm(grad.values, max_norm)
|
||||||
|
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||||
|
else:
|
||||||
|
grad = clip_ops.clip_by_norm(grad, max_norm)
|
||||||
|
clipped_grads_and_vars.append((grad, var))
|
||||||
|
return clipped_grads_and_vars
|
||||||
|
|
||||||
|
|
||||||
|
def multiply_gradients(grads_and_vars, gradient_multipliers):
  """Scales selected gradients by constant coefficients.

  A gradient is scaled when its variable, or the name of its variable's op, is
  a key of `gradient_multipliers`; all other pairs are returned untouched.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    gradient_multipliers: A map from either `Variables` or `Variable` op names
      to the coefficient by which the associated gradient should be scaled.

  Returns:
    The updated list of gradient to variable pairs.

  Raises:
    ValueError: If `grads_and_vars` is not a list, if `gradient_multipliers`
      is empty, `None` or not a dictionary, or if a multiplier is requested
      for a variable whose gradient is `None`.
  """
  if not isinstance(grads_and_vars, list):
    raise ValueError('`grads_and_vars` must be a list.')
  if not gradient_multipliers:
    raise ValueError('`gradient_multipliers` is empty.')
  if not isinstance(gradient_multipliers, dict):
    raise ValueError('`gradient_multipliers` must be a dict.')

  scaled = []
  for gradient, variable in grads_and_vars:
    # A multiplier keyed by the variable itself takes precedence over one
    # keyed by the variable's op name.
    if variable in gradient_multipliers:
      key = variable
    elif variable.op.name in gradient_multipliers:
      key = variable.op.name
    else:
      scaled.append((gradient, variable))
      continue

    if gradient is None:
      raise ValueError('Requested multiple of `None` gradient.')

    multiplier = constant_op.constant(
        gradient_multipliers[key], dtype=gradient.dtype)
    if isinstance(gradient, ops.IndexedSlices):
      gradient = ops.IndexedSlices(
          gradient.values * multiplier, gradient.indices,
          gradient.dense_shape)
    else:
      gradient *= multiplier
    scaled.append((gradient, variable))
  return scaled
|
||||||
|
|
||||||
|
|
||||||
|
def create_train_op(total_loss,
                    optimizer,
                    global_step=None,
                    update_ops=None,
                    variables_to_train=None,
                    transform_grads_fn=None,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False):
  """Creates an `Operation` that evaluates the gradients and returns the loss.

  Args:
    total_loss: A `Tensor` representing the total loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `None`, then `variables.get_or_create_global_step()` is used.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
      a warning will be displayed.
    variables_to_train: an optional list of variables to train. If None, it will
      default to all tf.trainable_variables().
    transform_grads_fn: A function which takes a single argument, a list of
      gradient to variable pairs (tuples), performs any requested gradient
      updates, such as gradient clipping or multipliers, and returns the updated
      list.
    summarize_gradients: Whether or not add summaries for each gradient.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: Whether or not to try colocating the gradients
      with the ops that generated them.

  Returns:
    A `Tensor` that when evaluated, computes the gradients and returns the total
    loss value.

  Raises:
    AssertionError: If an entry of `variables_to_train` is not a trainable
      variable, or if there are no variables to train.
  """
  if global_step is None:
    global_step = variables.get_or_create_global_step()

  # Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
  global_update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
  if update_ops is None:
    update_ops = global_update_ops
  else:
    update_ops = set(update_ops)
  if not global_update_ops.issubset(update_ops):
    logging.warning('update_ops in create_train_op does not contain all the '
                    'update_ops in GraphKeys.UPDATE_OPS')

  # Make sure update_ops are computed before total_loss.
  if update_ops:
    with ops.control_dependencies(update_ops):
      barrier = control_flow_ops.no_op(name='update_barrier')
    total_loss = control_flow_ops.with_dependencies([barrier], total_loss)

  # Query the trainable variables once instead of once per checked variable.
  trainable = tf_variables.trainable_variables()
  if variables_to_train is None:
    # Default to tf.trainable_variables()
    variables_to_train = trainable
  else:
    # Make sure that variables_to_train are in tf.trainable_variables()
    for v in variables_to_train:
      assert v in trainable

  assert variables_to_train

  # Create the gradients. Note that compute_gradients adds the gradient
  # computation to the current graph.
  grads = optimizer.compute_gradients(
      total_loss,
      variables_to_train,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)

  if transform_grads_fn:
    grads = transform_grads_fn(grads)

  # Summarize gradients.
  if summarize_gradients:
    with ops.name_scope('summarize_grads'):
      add_gradients_summaries(grads)

  # Create gradient updates.
  grad_updates = optimizer.apply_gradients(grads, global_step=global_step)

  with ops.name_scope('train_op'):
    # Make sure total_loss is valid.
    total_loss = array_ops.check_numerics(total_loss,
                                          'LossTensor is inf or nan')

    # Ensure the train_tensor computes grad_updates.
    return control_flow_ops.with_dependencies([grad_updates], total_loss)
|
||||||
|
|
||||||
|
|
||||||
|
def train(
    train_op,
    logdir,
    master='',
    is_chief=True,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=100,
    config=None):
  """Runs the training loop.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where the graph and checkpoints are saved.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: A tf.train.Scaffold instance. A default one is created when this
      is `None`.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop. The list passed in is not modified.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    The value returned by the last run of `train_op`, or `None` if the session
    asked to stop before `train_op` was run.

  Raises:
    ValueError: if this is the chief, `logdir` is `None` and either
      `save_checkpoint_secs` or `save_summaries_steps` is set.
  """
  # TODO(nsilberman): move this logic into monitored_session.py
  scaffold = scaffold or monitored_session.Scaffold()

  # Work on a private copy: the extend/append calls below must not leak the
  # default hooks into the list object owned by the caller.
  hooks = list(hooks) if hooks else []

  if is_chief:
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=logdir,
        master=master,
        config=config)

    if chief_only_hooks:
      hooks.extend(chief_only_hooks)

    hooks.append(basic_session_run_hooks.StepCounterHook(
        output_dir=logdir))

    if save_summaries_steps:
      if logdir is None:
        raise ValueError(
            'logdir cannot be None when save_summaries_steps is not None')
      hooks.append(basic_session_run_hooks.SummarySaverHook(
          scaffold=scaffold,
          save_steps=save_summaries_steps,
          output_dir=logdir))

    if save_checkpoint_secs:
      if logdir is None:
        raise ValueError(
            'logdir cannot be None when save_checkpoint_secs is not None')
      hooks.append(basic_session_run_hooks.CheckpointSaverHook(
          logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
  else:
    session_creator = monitored_session.WorkerSessionCreator(
        scaffold=scaffold, master=master, config=config)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    loss = None
    while not session.should_stop():
      loss = session.run(train_op)
  return loss
|
514
tensorflow/contrib/training/python/training/training_test.py
Normal file
514
tensorflow/contrib/training/python/training/training_test.py
Normal file
@ -0,0 +1,514 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tf.contrib.training.training."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def logistic_classifier(inputs):
  """Returns a one-unit fully connected layer over `inputs` with a sigmoid."""
  num_outputs = 1
  return tf.contrib.layers.fully_connected(
      inputs, num_outputs, activation_fn=tf.sigmoid)
|
||||||
|
|
||||||
|
|
||||||
|
def batchnorm_classifier(inputs):
  """Batch-normalizes `inputs` (decay 0.1), then applies a one-unit sigmoid."""
  normalized = tf.contrib.layers.batch_norm(inputs, decay=0.1)
  return tf.contrib.layers.fully_connected(
      normalized, 1, activation_fn=tf.sigmoid)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateTrainOpTest(tf.test.TestCase):
  """Checks how `create_train_op` treats the batch-norm update ops."""

  def setUp(self):
    np.random.seed(0)

    # Create an easy training set:
    self._inputs = np.random.rand(16, 4).astype(np.float32)
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

  def _build_train_op_and_moving_stats(self, **create_train_op_kwargs):
    """Builds the batch-norm model in the current default graph.

    Args:
      **create_train_op_kwargs: Extra keyword arguments forwarded to
        `tf.contrib.training.create_train_op`.

    Returns:
      A `(train_op, moving_mean, moving_variance)` tuple.
    """
    tf.set_random_seed(0)
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = batchnorm_classifier(tf_inputs)
    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
    total_loss = tf.contrib.losses.get_total_loss()
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

    train_op = tf.contrib.training.create_train_op(
        total_loss, optimizer, **create_train_op_kwargs)

    moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
    moving_variance = tf.contrib.framework.get_variables_by_name(
        'moving_variance')[0]
    return train_op, moving_mean, moving_variance

  def _assert_freshly_initialized(self, sess, moving_mean, moving_variance):
    """Asserts moving_mean == 0 and moving_variance == 1 after initialization."""
    mean, variance = sess.run([moving_mean, moving_variance])
    self.assertAllClose(mean, [0] * 4)
    self.assertAllClose(variance, [1] * 4)

  def testUseUpdateOps(self):
    with tf.Graph().as_default():
      train_op, moving_mean, moving_variance = (
          self._build_train_op_and_moving_stats())

      expected_mean = np.mean(self._inputs, axis=(0))
      expected_var = np.var(self._inputs, axis=(0))

      with tf.Session() as sess:
        # Initialize all variables
        sess.run(tf.initialize_all_variables())
        self._assert_freshly_initialized(sess, moving_mean, moving_variance)

        for _ in range(10):
          sess.run([train_op])

        # After 10 updates with decay 0.1 moving_mean == expected_mean and
        # moving_variance == expected_var.
        self.assertAllClose(moving_mean.eval(), expected_mean)
        self.assertAllClose(moving_variance.eval(), expected_var)

  def testEmptyUpdateOps(self):
    with tf.Graph().as_default():
      train_op, moving_mean, moving_variance = (
          self._build_train_op_and_moving_stats(update_ops=[]))

      with tf.Session() as sess:
        # Initialize all variables
        sess.run(tf.initialize_all_variables())
        self._assert_freshly_initialized(sess, moving_mean, moving_variance)

        for _ in range(10):
          sess.run([train_op])

        # Since we skip update_ops the moving_vars are not updated.
        self.assertAllClose(moving_mean.eval(), [0] * 4)
        self.assertAllClose(moving_variance.eval(), [1] * 4)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainBNClassifierTest(tf.test.TestCase):
  """Checks that the batch-norm classifier can be trained to a low loss."""

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
    self._logdir = os.path.join(self.get_temp_dir(), 'tmp_bnlogs/')

    for i in range(16):
      # Each row gets a single 1 in column 2 * label + (random 0 or 1). The
      # label is indexed as a scalar because calling int() on a one-element
      # array is deprecated in recent NumPy releases.
      j = int(2 * self._labels[i, 0] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, self._logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      # `train` returns None when no step was run; check that explicitly so
      # the comparison below cannot pass by accident.
      self.assertIsNotNone(loss)
      self.assertLess(loss, .1)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainTest(tf.test.TestCase):
  """End-to-end tests of `create_train_op` and `train` on a logistic model."""

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      # Each row gets a single 1 in column 2 * label + (random 0 or 1). The
      # label is indexed as a scalar because calling int() on a one-element
      # array is deprecated in recent NumPy releases.
      j = int(2 * self._labels[i, 0] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testCanAchieveZeroLoss(self):
    logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')

    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testTrainWithLocalVariable(self):
    logdir = os.path.join(self.get_temp_dir(), 'train_with_local_variable')

    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      local_multiplier = tf.contrib.framework.local_variable(1.0)

      tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testResumeTrainAchievesRoughlyTheSameLoss(self):
    number_of_steps = [300, 1, 5]
    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

    # Every pass reuses the same logdir; the pass index doubles as the graph
    # random seed.
    for i, num_steps in enumerate(number_of_steps):
      with tf.Graph().as_default():
        tf.set_random_seed(i)
        tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
        tf_labels = tf.constant(self._labels, dtype=tf.float32)

        tf_predictions = logistic_classifier(tf_inputs)
        tf.contrib.losses.log_loss(tf_predictions, tf_labels)
        total_loss = tf.contrib.losses.get_total_loss()

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

        train_op = tf.contrib.training.create_train_op(
            total_loss, optimizer)

        saver = tf.train.Saver()

        loss = tf.contrib.training.train(
            train_op, logdir, hooks=[
                tf.train.StopAtStepHook(num_steps=num_steps),
                tf.train.CheckpointSaverHook(
                    logdir, save_steps=50, saver=saver),
            ])
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)

  def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
    """Builds the logistic model and returns its train op.

    Args:
      learning_rate: Learning rate of the gradient descent optimizer.
      gradient_multiplier: Coefficient applied to the gradients of all
        trainable variables; gradients are left untouched when it is 1.0.

    Returns:
      The result of `tf.contrib.training.create_train_op` for the model.
    """
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
    total_loss = tf.contrib.losses.get_total_loss()

    optimizer = tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate)

    def transform_grads_fn(grads):
      if gradient_multiplier != 1.0:
        variables = tf.trainable_variables()
        gradient_multipliers = {var: gradient_multiplier for var in variables}

        with tf.name_scope('multiply_grads'):
          return tf.contrib.training.multiply_gradients(
              grads, gradient_multipliers)
      else:
        return grads

    return tf.contrib.training.create_train_op(
        total_loss, optimizer, transform_grads_fn=transform_grads_fn)

  def testTrainWithInitFromCheckpoint(self):
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if tf.gfile.Exists(logdir1):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir1)
    if tf.gfile.Exists(logdir2):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op()
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=1),
          ], save_checkpoint_secs=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with tf.Graph().as_default():
      tf.set_random_seed(1)
      train_op = self.create_train_op()
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=300),
          ], save_checkpoint_secs=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with tf.Graph().as_default():
      tf.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = tf.all_variables()
      model_path = os.path.join(logdir1, 'model.ckpt-300')

      assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
          model_path, model_variables)

      def init_fn(_, session):
        assign_fn(session)

      loss = tf.contrib.training.train(
          train_op,
          logdir2,
          scaffold=tf.train.Scaffold(init_fn=init_fn),
          hooks=[tf.train.StopAtStepHook(num_steps=1)])

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

  def ModelLoss(self):
    """Builds the logistic model and returns its total loss tensor."""
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
    return tf.contrib.losses.get_total_loss()

  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if tf.gfile.Exists(logdir):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir)

    # First, train only the weights of the model.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      weights = tf.contrib.framework.get_variables_by_name('weights')

      train_op = tf.contrib.training.create_train_op(
          total_loss,
          optimizer,
          variables_to_train=weights)

      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=200),
          ])
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Next, train the biases of the model.
    with tf.Graph().as_default():
      tf.set_random_seed(1)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      biases = tf.contrib.framework.get_variables_by_name('biases')

      train_op = tf.contrib.training.create_train_op(
          total_loss,
          optimizer,
          variables_to_train=biases)

      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=300),
          ])
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Finally, train both weights and bias to get lower loss.
    with tf.Graph().as_default():
      tf.set_random_seed(2)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=400),
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    # First, train only the weights of the model.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      weights, biases = tf.contrib.framework.get_variables()

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
      train_weights = tf.contrib.training.create_train_op(
          total_loss, optimizer, variables_to_train=[weights])
      train_biases = tf.contrib.training.create_train_op(
          total_loss, optimizer, variables_to_train=[biases])

      with tf.Session() as sess:
        # Initialize the variables.
        sess.run(tf.initialize_all_variables())

        # Get the initial weights and biases values.
        weights_values, biases_values = sess.run([weights, biases])
        self.assertGreater(np.linalg.norm(weights_values), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values), 0)

        # Update weights and biases.
        loss = sess.run(train_op)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights and biases have been updated.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

        weights_values, biases_values = new_weights, new_biases

        # Update only weights.
        loss = sess.run(train_weights)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights have been updated, but biases have not.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
        weights_values = new_weights

        # Update only biases.
        loss = sess.run(train_biases)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the biases have been updated, but weights have not.
        self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

  def testTrainWithAlteredGradients(self):
    # Use the same learning rate but different gradient multipliers
    # to train two models. Model with equivalently larger learning
    # rate (i.e., learning_rate * gradient_multiplier) has smaller
    # training loss.
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')

    if tf.gfile.Exists(logdir1):
      tf.gfile.DeleteRecursively(logdir1)
    if tf.gfile.Exists(logdir2):
      tf.gfile.DeleteRecursively(logdir2)

    multipliers = [1., 1000.]
    number_of_steps = 10
    losses = []
    learning_rate = 0.001

    # First, train the model with equivalently smaller learning rate.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate,
          gradient_multiplier=multipliers[0])

      saver = tf.train.Saver()

      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.StopAtStepHook(num_steps=number_of_steps),
              tf.train.CheckpointSaverHook(logdir1, save_steps=50, saver=saver),
          ])

      losses.append(loss)
      self.assertGreater(loss, .5)

    # Second, train the model with equivalently larger learning rate.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate,
          gradient_multiplier=multipliers[1])
      saver = tf.train.Saver()

      loss = tf.contrib.training.train(
          train_op, logdir2, hooks=[
              tf.train.StopAtStepHook(num_steps=number_of_steps),
              tf.train.CheckpointSaverHook(logdir2, save_steps=50, saver=saver),
          ])

      losses.append(loss)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .5)

    # The loss of the model trained with larger learning rate should
    # be smaller.
    self.assertGreater(losses[0], losses[1])
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test cases above when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
|
@ -221,13 +221,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "jpeg",
|
|
||||||
hdrs = ["lib/jpeg/jpeg_mem.h"],
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = [":jpeg_internal"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test support library needed for all tests
|
# Test support library needed for all tests
|
||||||
# This is currently public, but may be made internal in the
|
# This is currently public, but may be made internal in the
|
||||||
# future. Try to avoid depending on it.
|
# future. Try to avoid depending on it.
|
||||||
@ -699,9 +692,9 @@ filegroup(
|
|||||||
"platform/cuda.h",
|
"platform/cuda.h",
|
||||||
"platform/google/**/*",
|
"platform/google/**/*",
|
||||||
"platform/hadoop/**/*",
|
"platform/hadoop/**/*",
|
||||||
"platform/jpeg.*",
|
"platform/gif.h",
|
||||||
"platform/png.*",
|
"platform/jpeg.h",
|
||||||
"platform/gif.*",
|
"platform/png.h",
|
||||||
"platform/stream_executor.*",
|
"platform/stream_executor.*",
|
||||||
"platform/windows/**/*",
|
"platform/windows/**/*",
|
||||||
"user_ops/**/*.cu.cc",
|
"user_ops/**/*.cu.cc",
|
||||||
@ -981,7 +974,10 @@ cc_library(
|
|||||||
],
|
],
|
||||||
exclude = [
|
exclude = [
|
||||||
"**/*test*",
|
"**/*test*",
|
||||||
|
"lib/gif/**/*",
|
||||||
"lib/jpeg/**/*",
|
"lib/jpeg/**/*",
|
||||||
|
"platform/gif.h",
|
||||||
|
"platform/jpeg.h",
|
||||||
"platform/**/cuda.h",
|
"platform/**/cuda.h",
|
||||||
"platform/**/stream_executor.h",
|
"platform/**/stream_executor.h",
|
||||||
"platform/load_library.cc",
|
"platform/load_library.cc",
|
||||||
@ -998,7 +994,10 @@ cc_library(
|
|||||||
],
|
],
|
||||||
exclude = [
|
exclude = [
|
||||||
"**/*test*",
|
"**/*test*",
|
||||||
|
"lib/gif/**/*",
|
||||||
"lib/jpeg/**/*",
|
"lib/jpeg/**/*",
|
||||||
|
"platform/gif.h",
|
||||||
|
"platform/jpeg.h",
|
||||||
"platform/**/cuda.h",
|
"platform/**/cuda.h",
|
||||||
"platform/**/stream_executor.h",
|
"platform/**/stream_executor.h",
|
||||||
],
|
],
|
||||||
@ -1016,7 +1015,6 @@ cc_library(
|
|||||||
hdrs = tf_additional_lib_hdrs() + [
|
hdrs = tf_additional_lib_hdrs() + [
|
||||||
"lib/core/blocking_counter.h",
|
"lib/core/blocking_counter.h",
|
||||||
"lib/core/refcount.h",
|
"lib/core/refcount.h",
|
||||||
"lib/gif/gif_io.h",
|
|
||||||
"lib/gtl/edit_distance.h",
|
"lib/gtl/edit_distance.h",
|
||||||
"lib/gtl/int_type.h",
|
"lib/gtl/int_type.h",
|
||||||
"lib/gtl/iterator_range.h",
|
"lib/gtl/iterator_range.h",
|
||||||
@ -1060,18 +1058,32 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gif_internal",
|
||||||
|
srcs = [
|
||||||
|
"lib/gif/gif_io.cc",
|
||||||
|
"platform/gif.h",
|
||||||
|
],
|
||||||
|
hdrs = ["lib/gif/gif_io.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
linkopts = ["-ldl"],
|
||||||
|
deps = [
|
||||||
|
":lib",
|
||||||
|
"//tensorflow/core/platform/default/build_config:gif",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "jpeg_internal",
|
name = "jpeg_internal",
|
||||||
srcs = glob(
|
srcs = [
|
||||||
[
|
"lib/jpeg/jpeg_handle.cc",
|
||||||
"lib/jpeg/*h",
|
"lib/jpeg/jpeg_mem.cc",
|
||||||
"lib/jpeg/*.cc",
|
"platform/jpeg.h",
|
||||||
],
|
],
|
||||||
exclude = [
|
hdrs = [
|
||||||
"**/*test*",
|
"lib/jpeg/jpeg_handle.h",
|
||||||
],
|
"lib/jpeg/jpeg_mem.h",
|
||||||
),
|
],
|
||||||
hdrs = ["lib/jpeg/jpeg_handle.h"],
|
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
linkopts = ["-ldl"],
|
linkopts = ["-ldl"],
|
||||||
deps = [
|
deps = [
|
||||||
@ -1541,7 +1553,6 @@ cc_test(
|
|||||||
srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
|
srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
|
||||||
data = glob(["lib/jpeg/testdata/*.jpg"]),
|
data = glob(["lib/jpeg/testdata/*.jpg"]),
|
||||||
deps = [
|
deps = [
|
||||||
":jpeg",
|
|
||||||
":jpeg_internal",
|
":jpeg_internal",
|
||||||
":lib",
|
":lib",
|
||||||
":lib_internal",
|
":lib_internal",
|
||||||
|
@ -78,7 +78,7 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
|
|||||||
Status DeviceFactory::AddDevices(const SessionOptions& options,
|
Status DeviceFactory::AddDevices(const SessionOptions& options,
|
||||||
const string& name_prefix,
|
const string& name_prefix,
|
||||||
std::vector<Device*>* devices) {
|
std::vector<Device*>* devices) {
|
||||||
// CPU first.
|
// CPU first. A CPU device is required.
|
||||||
auto cpu_factory = GetFactory("CPU");
|
auto cpu_factory = GetFactory("CPU");
|
||||||
if (!cpu_factory) {
|
if (!cpu_factory) {
|
||||||
return errors::NotFound(
|
return errors::NotFound(
|
||||||
@ -90,18 +90,11 @@ Status DeviceFactory::AddDevices(const SessionOptions& options,
|
|||||||
return errors::NotFound("No CPU devices are available in this process");
|
return errors::NotFound("No CPU devices are available in this process");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then GPU.
|
// Then the rest (including GPU).
|
||||||
auto gpu_factory = GetFactory("GPU");
|
|
||||||
if (gpu_factory) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
gpu_factory->CreateDevices(options, name_prefix, devices));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then the rest.
|
|
||||||
mutex_lock l(*get_device_factory_lock());
|
mutex_lock l(*get_device_factory_lock());
|
||||||
for (auto& p : device_factories()) {
|
for (auto& p : device_factories()) {
|
||||||
auto factory = p.second.factory.get();
|
auto factory = p.second.factory.get();
|
||||||
if (factory != cpu_factory && factory != gpu_factory) {
|
if (factory != cpu_factory) {
|
||||||
TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
|
TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -282,6 +282,7 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
|
|||||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
mu_.unlock();
|
||||||
|
|
||||||
SchedClosure([session, req, resp, done]() {
|
SchedClosure([session, req, resp, done]() {
|
||||||
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
|
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
|
||||||
@ -290,7 +291,22 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
|
|||||||
}
|
}
|
||||||
done(status);
|
done(status);
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void Master::PartialRunSetup(const PartialRunSetupRequest* req,
|
||||||
|
PartialRunSetupResponse* resp, MyClosure done) {
|
||||||
|
mu_.lock();
|
||||||
|
MasterSession* session = gtl::FindPtrOrNull(sessions_, req->session_handle());
|
||||||
|
if (session == nullptr) {
|
||||||
|
mu_.unlock();
|
||||||
|
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||||
|
return;
|
||||||
|
}
|
||||||
mu_.unlock();
|
mu_.unlock();
|
||||||
|
|
||||||
|
SchedClosure([this, session, req, resp, done]() {
|
||||||
|
done(session->PartialRunSetup(req, resp));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
|
void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
|
||||||
@ -303,6 +319,7 @@ void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
|
|||||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
mu_.unlock();
|
||||||
|
|
||||||
SchedClosure([this, start_time, session, opts, req, resp, done]() {
|
SchedClosure([this, start_time, session, opts, req, resp, done]() {
|
||||||
Status status = session->Run(opts, req, resp);
|
Status status = session->Run(opts, req, resp);
|
||||||
@ -312,7 +329,6 @@ void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
|
|||||||
last_1000_steps_.AddValue((done_time - start_time) / 1e9);
|
last_1000_steps_.AddValue((done_time - start_time) / 1e9);
|
||||||
++step_count_;
|
++step_count_;
|
||||||
});
|
});
|
||||||
mu_.unlock();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Master::CloseSession(const CloseSessionRequest* req,
|
void Master::CloseSession(const CloseSessionRequest* req,
|
||||||
|
@ -46,6 +46,9 @@ class Master {
|
|||||||
void ExtendSession(const ExtendSessionRequest* req,
|
void ExtendSession(const ExtendSessionRequest* req,
|
||||||
ExtendSessionResponse* resp, MyClosure done);
|
ExtendSessionResponse* resp, MyClosure done);
|
||||||
|
|
||||||
|
void PartialRunSetup(const PartialRunSetupRequest* req,
|
||||||
|
PartialRunSetupResponse* resp, MyClosure done);
|
||||||
|
|
||||||
void RunStep(CallOptions* opts, const RunStepRequest* req,
|
void RunStep(CallOptions* opts, const RunStepRequest* req,
|
||||||
RunStepResponse* resp, MyClosure done);
|
RunStepResponse* resp, MyClosure done);
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
|
||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/protobuf/master.pb.h"
|
#include "tensorflow/core/protobuf/master.pb.h"
|
||||||
|
|
||||||
@ -37,6 +38,12 @@ class MasterInterface {
|
|||||||
const ExtendSessionRequest* request,
|
const ExtendSessionRequest* request,
|
||||||
ExtendSessionResponse* response) = 0;
|
ExtendSessionResponse* response) = 0;
|
||||||
|
|
||||||
|
virtual Status PartialRunSetup(CallOptions* call_options,
|
||||||
|
const PartialRunSetupRequest* request,
|
||||||
|
PartialRunSetupResponse* response) {
|
||||||
|
return errors::Unimplemented("Partial run not implemented for this master");
|
||||||
|
}
|
||||||
|
|
||||||
virtual Status RunStep(CallOptions* call_options,
|
virtual Status RunStep(CallOptions* call_options,
|
||||||
const RunStepRequest* request,
|
const RunStepRequest* request,
|
||||||
RunStepResponse* response) = 0;
|
RunStepResponse* response) = 0;
|
||||||
|
@ -50,18 +50,6 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// A little bit of per-step state.
|
|
||||||
struct PerStepState {
|
|
||||||
bool collect_costs = false;
|
|
||||||
bool collect_timeline = false;
|
|
||||||
bool collect_rpcs = false;
|
|
||||||
Microseconds start_micros = Microseconds(0);
|
|
||||||
Microseconds end_micros = Microseconds(0);
|
|
||||||
std::vector<StepStats> step_stats; // per partition
|
|
||||||
StepStats rpc_stats; // for RPC layer
|
|
||||||
CostGraphDef cost_graph;
|
|
||||||
};
|
|
||||||
|
|
||||||
// MasterSession wraps SimpleClientGraph in a reference counted object.
|
// MasterSession wraps SimpleClientGraph in a reference counted object.
|
||||||
// This way, MasterSession can clear up the cache mapping Run requests to
|
// This way, MasterSession can clear up the cache mapping Run requests to
|
||||||
// compiled graphs while the compiled graph is still being used.
|
// compiled graphs while the compiled graph is still being used.
|
||||||
@ -72,15 +60,38 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
|||||||
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
|
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
|
||||||
std::unique_ptr<SimpleClientGraph> cg,
|
std::unique_ptr<SimpleClientGraph> cg,
|
||||||
const SessionOptions& session_opts,
|
const SessionOptions& session_opts,
|
||||||
StatsPublisherFactory stats_publisher_factory)
|
StatsPublisherFactory stats_publisher_factory,
|
||||||
|
SimpleGraphExecutionState* execution_state, bool is_partial)
|
||||||
: session_handle_(handle),
|
: session_handle_(handle),
|
||||||
client_graph_(std::move(cg)),
|
client_graph_(std::move(cg)),
|
||||||
bopts_(bopts),
|
bopts_(bopts),
|
||||||
session_opts_(session_opts) {
|
session_opts_(session_opts),
|
||||||
|
is_partial_(is_partial) {
|
||||||
VLOG(1) << "Created ReffedClientGraph for node with "
|
VLOG(1) << "Created ReffedClientGraph for node with "
|
||||||
<< client_graph_->graph.num_node_ids();
|
<< client_graph_->graph.num_node_ids();
|
||||||
|
|
||||||
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
|
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
|
||||||
|
|
||||||
|
// If this is a partial run we need to initialize a name to node map for
|
||||||
|
// testing that fetches are reachable.
|
||||||
|
if (is_partial) {
|
||||||
|
std::unordered_set<StringPiece, StringPiece::Hasher> names;
|
||||||
|
for (const string& input : bopts.feed_endpoints) {
|
||||||
|
TensorId id(ParseTensorName(input));
|
||||||
|
names.emplace(id.first);
|
||||||
|
}
|
||||||
|
for (const string& output : bopts.fetch_endpoints) {
|
||||||
|
TensorId id(ParseTensorName(output));
|
||||||
|
names.emplace(id.first);
|
||||||
|
}
|
||||||
|
// We use the graph from the execution_state because we want the graph
|
||||||
|
// nodes before they are rewritten replaced by the rewriter.
|
||||||
|
for (Node* n : execution_state->full_graph()->nodes()) {
|
||||||
|
if (names.count(n->name()) > 0) {
|
||||||
|
name_to_node_.insert({n->name(), n});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~ReffedClientGraph() override { DeregisterPartitions(); }
|
~ReffedClientGraph() override { DeregisterPartitions(); }
|
||||||
@ -171,7 +182,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
|||||||
SimpleGraphExecutionState* execution_state,
|
SimpleGraphExecutionState* execution_state,
|
||||||
PerStepState* pss, CallOptions* opts,
|
PerStepState* pss, CallOptions* opts,
|
||||||
const RunStepRequest& req, RunStepResponse* resp,
|
const RunStepRequest& req, RunStepResponse* resp,
|
||||||
CancellationManager* cm);
|
CancellationManager* cm, const bool is_last_partial_run);
|
||||||
|
|
||||||
// Calls workers to cleanup states for the step "step_id". Calls
|
// Calls workers to cleanup states for the step "step_id". Calls
|
||||||
// `done` when all cleanup RPCs have completed.
|
// `done` when all cleanup RPCs have completed.
|
||||||
@ -185,6 +196,9 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
|||||||
void ProcessDeviceStats(ProfileHandler* ph,
|
void ProcessDeviceStats(ProfileHandler* ph,
|
||||||
const SimpleGraphExecutionState* execution_state,
|
const SimpleGraphExecutionState* execution_state,
|
||||||
const DeviceStepStats& ds, bool is_rpc);
|
const DeviceStepStats& ds, bool is_rpc);
|
||||||
|
// Checks that the requested fetches can be computed from the provided feeds.
|
||||||
|
Status CheckFetches(const RunStepRequest& req, const RunState* run_state,
|
||||||
|
SimpleGraphExecutionState* execution_state);
|
||||||
|
|
||||||
string DetailText(const NodeDef& def, const NodeExecStats& ns) {
|
string DetailText(const NodeDef& def, const NodeExecStats& ns) {
|
||||||
int64 tot = 0;
|
int64 tot = 0;
|
||||||
@ -209,6 +223,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
|||||||
std::unordered_set<const Node*> nodes_needing_input_mapping_;
|
std::unordered_set<const Node*> nodes_needing_input_mapping_;
|
||||||
BuildGraphOptions bopts_;
|
BuildGraphOptions bopts_;
|
||||||
const SessionOptions session_opts_;
|
const SessionOptions session_opts_;
|
||||||
|
const bool is_partial_;
|
||||||
|
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> name_to_node_;
|
||||||
|
|
||||||
// Graph partitioned into per-location subgraphs.
|
// Graph partitioned into per-location subgraphs.
|
||||||
struct Part {
|
struct Part {
|
||||||
@ -483,15 +499,14 @@ class RunManyGraphs {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
|
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
Status MasterSession::ReffedClientGraph::RunPartitions(
|
Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||||
const MasterEnv* env, int64 step_id, int64 execution_count,
|
const MasterEnv* env, int64 step_id, int64 execution_count,
|
||||||
SimpleGraphExecutionState* execution_state, PerStepState* pss,
|
SimpleGraphExecutionState* execution_state, PerStepState* pss,
|
||||||
CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp,
|
CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp,
|
||||||
CancellationManager* cm) {
|
CancellationManager* cm, const bool is_last_partial_run) {
|
||||||
VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
|
VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
|
||||||
<< execution_count;
|
<< execution_count;
|
||||||
// Builds an index for feeds provided by the client.
|
// Build an index for feeds provided by the client.
|
||||||
std::unordered_map<StringPiece, const TensorProto*, StringPiece::Hasher>
|
std::unordered_map<StringPiece, const TensorProto*, StringPiece::Hasher>
|
||||||
feeds(3);
|
feeds(3);
|
||||||
|
|
||||||
@ -524,26 +539,64 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
|||||||
for (int i = 0; i < num; ++i) {
|
for (int i = 0; i < num; ++i) {
|
||||||
const Part& part = partitions_[i];
|
const Part& part = partitions_[i];
|
||||||
RunManyGraphs::Call* c = calls.get(i);
|
RunManyGraphs::Call* c = calls.get(i);
|
||||||
|
if (is_partial_) {
|
||||||
|
c->req.set_is_partial(is_partial_);
|
||||||
|
c->req.set_is_last_partial_run(is_last_partial_run);
|
||||||
|
}
|
||||||
c->req.set_graph_handle(part.graph_handle);
|
c->req.set_graph_handle(part.graph_handle);
|
||||||
c->req.set_step_id(step_id);
|
c->req.set_step_id(step_id);
|
||||||
*c->req.mutable_exec_opts() = exec_opts;
|
*c->req.mutable_exec_opts() = exec_opts;
|
||||||
// If any feeds are provided, send the feed values together
|
// If any feeds are provided, send the feed values together
|
||||||
// in the RunGraph request.
|
// in the RunGraph request.
|
||||||
for (const auto& feed_key : part.feed_key) {
|
// In the partial case, we only want to include feeds provided in the req.
|
||||||
const string& feed = feed_key.first;
|
// In the non-partial case, all feeds in the request are in the part.
|
||||||
const string& key = feed_key.second;
|
// We keep these as separate paths for now, to ensure we aren't
|
||||||
const TensorProto* val = feeds[feed];
|
// inadvertently slowing down the normal run path.
|
||||||
if (val == nullptr) {
|
if (is_partial_) {
|
||||||
return errors::InvalidArgument("No feed is provided for feed=", feed,
|
for (const auto& feed : req.feed()) {
|
||||||
", key=", key);
|
const string& name = feed.name();
|
||||||
|
auto iter = part.feed_key.find(name);
|
||||||
|
if (iter == part.feed_key.end()) {
|
||||||
|
// The provided feed must be for a different partition.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const string& key = iter->second;
|
||||||
|
const TensorProto* val = feeds[name];
|
||||||
|
if (val == nullptr) {
|
||||||
|
return errors::InvalidArgument("No feed is provided for feed=", name,
|
||||||
|
", key=", key);
|
||||||
|
}
|
||||||
|
auto* send = c->req.add_send();
|
||||||
|
send->set_key(key);
|
||||||
|
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
|
||||||
|
}
|
||||||
|
// TODO(suharshs): Make a map from feed to fetch_key to make this faster.
|
||||||
|
// For now, we just iterate through partitions to find the matching key.
|
||||||
|
for (const auto& req_fetch : req.fetch()) {
|
||||||
|
for (const auto& key_fetch : part.key_fetch) {
|
||||||
|
if (key_fetch.second == req_fetch) {
|
||||||
|
c->req.add_recv_key(key_fetch.first);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (const auto& feed_key : part.feed_key) {
|
||||||
|
const string& feed = feed_key.first;
|
||||||
|
const string& key = feed_key.second;
|
||||||
|
const TensorProto* val = feeds[feed];
|
||||||
|
if (val == nullptr) {
|
||||||
|
return errors::InvalidArgument("No feed is provided for feed=", feed,
|
||||||
|
", key=", key);
|
||||||
|
}
|
||||||
|
auto* send = c->req.add_send();
|
||||||
|
send->set_key(key);
|
||||||
|
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
|
||||||
|
}
|
||||||
|
for (const auto& key_fetch : part.key_fetch) {
|
||||||
|
const string& key = key_fetch.first;
|
||||||
|
c->req.add_recv_key(key);
|
||||||
}
|
}
|
||||||
auto* send = c->req.add_send();
|
|
||||||
send->set_key(key);
|
|
||||||
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
|
|
||||||
}
|
|
||||||
for (const auto& key_fetch : part.key_fetch) {
|
|
||||||
const string& key = key_fetch.first;
|
|
||||||
c->req.add_recv_key(key);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -762,6 +815,64 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(suharshs): Merge with CheckFetches in DirectSession.
|
||||||
|
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
|
||||||
|
// on once at setup time to prevent us from computing the dependencies
|
||||||
|
// everytime.
|
||||||
|
Status MasterSession::ReffedClientGraph::CheckFetches(
|
||||||
|
const RunStepRequest& req, const RunState* run_state,
|
||||||
|
SimpleGraphExecutionState* execution_state) {
|
||||||
|
// Build the set of pending feeds that we haven't seen.
|
||||||
|
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
|
||||||
|
for (const string& feed : run_state->pending_inputs) {
|
||||||
|
TensorId id(ParseTensorName(feed));
|
||||||
|
auto it = name_to_node_.find(id.first);
|
||||||
|
if (it == name_to_node_.end()) {
|
||||||
|
return errors::NotFound("Feed ", feed, ": not found");
|
||||||
|
}
|
||||||
|
pending_feeds.insert(id);
|
||||||
|
}
|
||||||
|
for (const auto& feed : req.feed()) {
|
||||||
|
TensorId id(ParseTensorName(feed.name()));
|
||||||
|
pending_feeds.erase(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the stack with the fetch nodes.
|
||||||
|
std::vector<const Node*> stack;
|
||||||
|
for (const string& fetch : req.fetch()) {
|
||||||
|
TensorId id(ParseTensorName(fetch));
|
||||||
|
auto it = name_to_node_.find(id.first);
|
||||||
|
if (it == name_to_node_.end()) {
|
||||||
|
return errors::NotFound("Fetch ", fetch, ": not found");
|
||||||
|
}
|
||||||
|
stack.push_back(it->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Any tensor needed for fetches can't be in pending_feeds.
|
||||||
|
// We need to use the original full graph from execution state.
|
||||||
|
const Graph* graph = execution_state->full_graph();
|
||||||
|
std::vector<bool> visited(graph->num_node_ids(), false);
|
||||||
|
while (!stack.empty()) {
|
||||||
|
const Node* n = stack.back();
|
||||||
|
stack.pop_back();
|
||||||
|
|
||||||
|
for (const Edge* in_edge : n->in_edges()) {
|
||||||
|
const Node* in_node = in_edge->src();
|
||||||
|
if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
|
||||||
|
return errors::InvalidArgument("Fetch ", in_node->name(), ":",
|
||||||
|
in_edge->src_output(),
|
||||||
|
" can't be computed from the feeds"
|
||||||
|
" that have been fed so far.");
|
||||||
|
}
|
||||||
|
if (!visited[in_node->id()]) {
|
||||||
|
visited[in_node->id()] = true;
|
||||||
|
stack.push_back(in_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Asynchronously deregisters subgraphs on the workers, without waiting for the
|
// Asynchronously deregisters subgraphs on the workers, without waiting for the
|
||||||
// result.
|
// result.
|
||||||
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
|
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
|
||||||
@ -803,6 +914,23 @@ void BuildBuildGraphOptions(const RunStepRequest& req,
|
|||||||
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
|
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
|
||||||
|
BuildGraphOptions* opts) {
|
||||||
|
for (const auto& feed : req.feed()) {
|
||||||
|
opts->feed_endpoints.push_back(feed);
|
||||||
|
}
|
||||||
|
for (const auto& fetch : req.fetch()) {
|
||||||
|
opts->fetch_endpoints.push_back(fetch);
|
||||||
|
}
|
||||||
|
for (const auto& target : req.target()) {
|
||||||
|
opts->target_nodes.push_back(target);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
|
||||||
|
std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
|
||||||
|
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
|
||||||
|
}
|
||||||
|
|
||||||
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
|
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
|
||||||
uint64 h = 0x2b992ddfa23249d6ull;
|
uint64 h = 0x2b992ddfa23249d6ull;
|
||||||
for (const string& name : opts.feed_endpoints) {
|
for (const string& name : opts.feed_endpoints) {
|
||||||
@ -927,11 +1055,9 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MasterSession::StartStep(const RunStepRequest& req,
|
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
|
||||||
BuildGraphOptions* opts, int64* count,
|
ReffedClientGraph** rcg, bool is_partial) {
|
||||||
ReffedClientGraph** rcg) {
|
const uint64 hash = HashBuildGraphOptions(opts);
|
||||||
BuildBuildGraphOptions(req, opts);
|
|
||||||
const uint64 hash = HashBuildGraphOptions(*opts);
|
|
||||||
ReffedClientGraph* to_unref = nullptr;
|
ReffedClientGraph* to_unref = nullptr;
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
@ -944,12 +1070,12 @@ Status MasterSession::StartStep(const RunStepRequest& req,
|
|||||||
// We have not seen this subgraph before. Build the subgraph and
|
// We have not seen this subgraph before. Build the subgraph and
|
||||||
// cache it.
|
// cache it.
|
||||||
VLOG(1) << "Unseen hash " << hash << " for "
|
VLOG(1) << "Unseen hash " << hash << " for "
|
||||||
<< BuildGraphOptionsString(*opts);
|
<< BuildGraphOptionsString(opts);
|
||||||
std::unique_ptr<SimpleClientGraph> client_graph;
|
std::unique_ptr<SimpleClientGraph> client_graph;
|
||||||
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph));
|
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
|
||||||
auto entry =
|
auto entry = new ReffedClientGraph(
|
||||||
new ReffedClientGraph(handle_, *opts, std::move(client_graph),
|
handle_, opts, std::move(client_graph), session_opts_,
|
||||||
session_opts_, stats_publisher_factory_);
|
stats_publisher_factory_, execution_state_.get(), is_partial);
|
||||||
iter = runs_.insert({hash, entry}).first;
|
iter = runs_.insert({hash, entry}).first;
|
||||||
auto obs_iter = obsolete_.find(hash);
|
auto obs_iter = obsolete_.find(hash);
|
||||||
if (obs_iter != obsolete_.end()) {
|
if (obs_iter != obsolete_.end()) {
|
||||||
@ -979,6 +1105,47 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
|||||||
rcg_map->clear();
|
rcg_map->clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
|
||||||
|
PartialRunSetupResponse* resp) {
|
||||||
|
std::vector<string> inputs, outputs, targets;
|
||||||
|
for (const auto& feed : req->feed()) {
|
||||||
|
inputs.push_back(feed);
|
||||||
|
}
|
||||||
|
for (const auto& fetch : req->fetch()) {
|
||||||
|
outputs.push_back(fetch);
|
||||||
|
}
|
||||||
|
for (const auto& target : req->target()) {
|
||||||
|
targets.push_back(target);
|
||||||
|
}
|
||||||
|
|
||||||
|
string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
|
||||||
|
|
||||||
|
ReffedClientGraph* rcg = nullptr;
|
||||||
|
int64 count = 0;
|
||||||
|
|
||||||
|
// Prepare.
|
||||||
|
BuildGraphOptions opts;
|
||||||
|
BuildBuildGraphOptions(*req, &opts);
|
||||||
|
TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true));
|
||||||
|
// Keeps the highest 8 bits 0x01: we reserve some bits of the
|
||||||
|
// step_id for future use.
|
||||||
|
uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
||||||
|
TRACEPRINTF("stepid %llu", step_id);
|
||||||
|
|
||||||
|
rcg->Ref();
|
||||||
|
RunState* run_state = new RunState(inputs, outputs, rcg, step_id, count);
|
||||||
|
{
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
partial_runs_.emplace(
|
||||||
|
std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
|
||||||
|
|
||||||
|
resp->set_partial_run_handle(handle);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
|
Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
|
||||||
RunStepResponse* resp) {
|
RunStepResponse* resp) {
|
||||||
UpdateLastAccessTime();
|
UpdateLastAccessTime();
|
||||||
@ -986,7 +1153,12 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
|
|||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
++num_running_;
|
++num_running_;
|
||||||
}
|
}
|
||||||
Status status = DoRunWithLocalExecution(opts, req, resp);
|
Status status;
|
||||||
|
if (!req->partial_run_handle().empty()) {
|
||||||
|
status = DoPartialRun(opts, req, resp);
|
||||||
|
} else {
|
||||||
|
status = DoRunWithLocalExecution(opts, req, resp);
|
||||||
|
}
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
--num_running_;
|
--num_running_;
|
||||||
@ -997,23 +1169,7 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
||||||
const RunStepRequest* req,
|
|
||||||
RunStepResponse* resp) {
|
|
||||||
VLOG(2) << "DoRunWithLocalExecution "
|
|
||||||
<< "req: " << req->DebugString();
|
|
||||||
PerStepState pss;
|
|
||||||
pss.start_micros = Env::Default()->NowMicros();
|
|
||||||
|
|
||||||
// Prepare.
|
|
||||||
BuildGraphOptions bgopts;
|
|
||||||
ReffedClientGraph* rcg = nullptr;
|
|
||||||
int64 count = 0;
|
|
||||||
TF_RETURN_IF_ERROR(StartStep(*req, &bgopts, &count, &rcg));
|
|
||||||
|
|
||||||
// Unref "rcg" when out of scope.
|
|
||||||
core::ScopedUnref unref(rcg);
|
|
||||||
|
|
||||||
// Registers subgraphs if haven't done so.
|
// Registers subgraphs if haven't done so.
|
||||||
PartitionOptions popts;
|
PartitionOptions popts;
|
||||||
popts.node_to_loc = SplitByWorker;
|
popts.node_to_loc = SplitByWorker;
|
||||||
@ -1051,12 +1207,136 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
|||||||
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(
|
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(
|
||||||
env_, popts, rcg->client_graph()->flib_def->ToProto()));
|
env_, popts, rcg->client_graph()->flib_def->ToProto()));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequest* req,
|
||||||
|
RunStepResponse* resp) {
|
||||||
|
const string& prun_handle = req->partial_run_handle();
|
||||||
|
RunState* run_state = nullptr;
|
||||||
|
{
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
auto it = partial_runs_.find(prun_handle);
|
||||||
|
if (it == partial_runs_.end()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Must run PartialRunSetup before performing partial runs");
|
||||||
|
}
|
||||||
|
run_state = it->second.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is the first partial run, initialize the PerStepState.
|
||||||
|
if (!run_state->step_started) {
|
||||||
|
run_state->step_started = true;
|
||||||
|
PerStepState pss;
|
||||||
|
|
||||||
|
auto count = run_state->count;
|
||||||
|
pss.collect_timeline =
|
||||||
|
req->options().trace_level() == RunOptions::FULL_TRACE;
|
||||||
|
|
||||||
|
// Build the cost model every 'build_cost_model_every' steps after skipping
|
||||||
|
// an initial 'build_cost_model_after' steps.
|
||||||
|
const int64 build_cost_model_after =
|
||||||
|
session_opts_.config.graph_options().build_cost_model_after();
|
||||||
|
const int64 build_cost_model_every =
|
||||||
|
session_opts_.config.graph_options().build_cost_model();
|
||||||
|
pss.collect_costs =
|
||||||
|
build_cost_model_every > 0 &&
|
||||||
|
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
|
||||||
|
|
||||||
|
std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
|
||||||
|
run_state->step_id, count, req->options());
|
||||||
|
if (ph) {
|
||||||
|
pss.collect_timeline = true;
|
||||||
|
pss.collect_rpcs = ph->should_collect_rpcs();
|
||||||
|
}
|
||||||
|
|
||||||
|
run_state->pss = std::move(pss);
|
||||||
|
run_state->ph = std::move(ph);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure that this is a new set of feeds that are still pending.
|
||||||
|
for (const auto& feed : req->feed()) {
|
||||||
|
auto it = run_state->pending_inputs.find(feed.name());
|
||||||
|
if (it == run_state->pending_inputs.end()) {
|
||||||
|
return errors::InvalidArgument("The feed ", feed.name(),
|
||||||
|
" had already been fed.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check that this is a new set of fetches that are still pending.
|
||||||
|
for (const auto& fetch : req->fetch()) {
|
||||||
|
auto it = run_state->pending_outputs.find(fetch);
|
||||||
|
if (it == run_state->pending_outputs.end()) {
|
||||||
|
return errors::InvalidArgument("The fetch ", fetch,
|
||||||
|
" had already been fetched.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that the requested fetches can be computed from the provided feeds.
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
run_state->rcg->CheckFetches(*req, run_state, execution_state_.get()));
|
||||||
|
|
||||||
|
// Determine if this partial run satisfies all the pending inputs and outputs.
|
||||||
|
for (const auto& feed : req->feed()) {
|
||||||
|
run_state->pending_inputs.erase(feed.name());
|
||||||
|
}
|
||||||
|
for (const auto& fetch : req->fetch()) {
|
||||||
|
run_state->pending_outputs.erase(fetch);
|
||||||
|
}
|
||||||
|
bool is_last_partial_run =
|
||||||
|
(run_state->pending_inputs.empty() && run_state->pending_outputs.empty());
|
||||||
|
|
||||||
|
Status s = run_state->rcg->RunPartitions(
|
||||||
|
env_, run_state->step_id, run_state->count, execution_state_.get(),
|
||||||
|
&run_state->pss, opts, *req, resp, cancellation_manager_,
|
||||||
|
is_last_partial_run);
|
||||||
|
|
||||||
|
// Delete the run state if there is an error or all fetches are done.
|
||||||
|
if (!s.ok() || is_last_partial_run) {
|
||||||
|
ReffedClientGraph* rcg = run_state->rcg;
|
||||||
|
run_state->pss.end_micros = Env::Default()->NowMicros();
|
||||||
|
// Schedule post-processing and cleanup to be done asynchronously.
|
||||||
|
rcg->Ref();
|
||||||
|
rcg->ProcessStats(env_, run_state->step_id, &run_state->pss,
|
||||||
|
execution_state_.get(), run_state->ph.get(), *req, resp);
|
||||||
|
rcg->CleanupPartitionsAsync(
|
||||||
|
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
|
||||||
|
if (!s.ok()) {
|
||||||
|
LOG(ERROR) << "Cleanup partition error: " << s;
|
||||||
|
}
|
||||||
|
rcg->Unref();
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
partial_runs_.erase(prun_handle);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||||
|
const RunStepRequest* req,
|
||||||
|
RunStepResponse* resp) {
|
||||||
|
VLOG(2) << "DoRunWithLocalExecution "
|
||||||
|
<< "req: " << req->DebugString();
|
||||||
|
PerStepState pss;
|
||||||
|
pss.start_micros = Env::Default()->NowMicros();
|
||||||
|
|
||||||
|
// Prepare.
|
||||||
|
BuildGraphOptions bgopts;
|
||||||
|
BuildBuildGraphOptions(*req, &bgopts);
|
||||||
|
ReffedClientGraph* rcg = nullptr;
|
||||||
|
int64 count = 0;
|
||||||
|
TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false));
|
||||||
|
|
||||||
|
// Unref "rcg" when out of scope.
|
||||||
|
core::ScopedUnref unref(rcg);
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
|
||||||
|
|
||||||
// Keeps the highest 8 bits 0x01: we reserve some bits of the
|
// Keeps the highest 8 bits 0x01: we reserve some bits of the
|
||||||
// step_id for future use.
|
// step_id for future use.
|
||||||
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
||||||
TRACEPRINTF("stepid %llu", step_id);
|
TRACEPRINTF("stepid %llu", step_id);
|
||||||
|
|
||||||
std::unique_ptr<ProfileHandler> ph;
|
|
||||||
pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
|
pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
|
||||||
|
|
||||||
// Build the cost model every 'build_cost_model_every' steps after skipping an
|
// Build the cost model every 'build_cost_model_every' steps after skipping an
|
||||||
@ -1069,15 +1349,16 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
|||||||
build_cost_model_every > 0 &&
|
build_cost_model_every > 0 &&
|
||||||
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
|
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
|
||||||
|
|
||||||
ph = rcg->GetProfileHandler(step_id, count, req->options());
|
std::unique_ptr<ProfileHandler> ph =
|
||||||
|
rcg->GetProfileHandler(step_id, count, req->options());
|
||||||
if (ph) {
|
if (ph) {
|
||||||
pss.collect_timeline = true;
|
pss.collect_timeline = true;
|
||||||
pss.collect_rpcs = ph->should_collect_rpcs();
|
pss.collect_rpcs = ph->should_collect_rpcs();
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
|
TF_RETURN_IF_ERROR(
|
||||||
execution_state_.get(), &pss, opts,
|
rcg->RunPartitions(env_, step_id, count, execution_state_.get(), &pss,
|
||||||
*req, resp, cancellation_manager_));
|
opts, *req, resp, cancellation_manager_, false));
|
||||||
|
|
||||||
pss.end_micros = Env::Default()->NowMicros();
|
pss.end_micros = Env::Default()->NowMicros();
|
||||||
|
|
||||||
@ -1110,4 +1391,22 @@ Status MasterSession::Close() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MasterSession::RunState::RunState(const std::vector<string>& input_names,
|
||||||
|
const std::vector<string>& output_names,
|
||||||
|
ReffedClientGraph* rcg, const uint64 step_id,
|
||||||
|
const int64 count)
|
||||||
|
: rcg(rcg), step_id(step_id), count(count) {
|
||||||
|
// Initially all the feeds and fetches are pending.
|
||||||
|
for (auto& name : input_names) {
|
||||||
|
pending_inputs.emplace(name);
|
||||||
|
}
|
||||||
|
for (auto& name : output_names) {
|
||||||
|
pending_outputs.emplace(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MasterSession::RunState::~RunState() {
|
||||||
|
if (rcg) rcg->Unref();
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
|
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
|
||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/device_set.h"
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
@ -72,6 +73,10 @@ class MasterSession {
|
|||||||
// Extend() may block the caller thread for a long time.
|
// Extend() may block the caller thread for a long time.
|
||||||
Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);
|
Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);
|
||||||
|
|
||||||
|
// Set up a partial run call.
|
||||||
|
Status PartialRunSetup(const PartialRunSetupRequest* req,
|
||||||
|
PartialRunSetupResponse* resp);
|
||||||
|
|
||||||
// Run one step.
|
// Run one step.
|
||||||
Status Run(CallOptions* opts, const RunStepRequest* req,
|
Status Run(CallOptions* opts, const RunStepRequest* req,
|
||||||
RunStepResponse* resp);
|
RunStepResponse* resp);
|
||||||
@ -101,6 +106,8 @@ class MasterSession {
|
|||||||
|
|
||||||
std::atomic_ulong last_access_time_usec_;
|
std::atomic_ulong last_access_time_usec_;
|
||||||
|
|
||||||
|
std::atomic<int64> partial_run_handle_counter_ = {0};
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
std::unique_ptr<SimpleGraphExecutionState> execution_state_;
|
std::unique_ptr<SimpleGraphExecutionState> execution_state_;
|
||||||
int64 graph_version_;
|
int64 graph_version_;
|
||||||
@ -115,6 +122,36 @@ class MasterSession {
|
|||||||
RCGMap runs_ GUARDED_BY(mu_);
|
RCGMap runs_ GUARDED_BY(mu_);
|
||||||
RCGMap obsolete_ GUARDED_BY(mu_);
|
RCGMap obsolete_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
struct PerStepState {
|
||||||
|
bool collect_costs = false;
|
||||||
|
bool collect_timeline = false;
|
||||||
|
bool collect_rpcs = false;
|
||||||
|
Microseconds start_micros = Microseconds(0);
|
||||||
|
Microseconds end_micros = Microseconds(0);
|
||||||
|
std::vector<StepStats> step_stats; // per partition
|
||||||
|
StepStats rpc_stats; // for RPC layer
|
||||||
|
CostGraphDef cost_graph;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RunState {
|
||||||
|
std::unordered_set<string> pending_inputs;
|
||||||
|
std::unordered_set<string> pending_outputs;
|
||||||
|
ReffedClientGraph* rcg = nullptr;
|
||||||
|
uint64 step_id;
|
||||||
|
int64 count = 0;
|
||||||
|
PerStepState pss;
|
||||||
|
std::unique_ptr<ProfileHandler> ph;
|
||||||
|
bool step_started = false;
|
||||||
|
|
||||||
|
RunState(const std::vector<string>& input_names,
|
||||||
|
const std::vector<string>& output_names, ReffedClientGraph* rcg,
|
||||||
|
const uint64 step_id, const int64 count);
|
||||||
|
|
||||||
|
~RunState();
|
||||||
|
};
|
||||||
|
std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
|
||||||
|
GUARDED_BY(mu_);
|
||||||
|
|
||||||
// Active RunStep calls.
|
// Active RunStep calls.
|
||||||
condition_variable num_running_is_zero_;
|
condition_variable num_running_is_zero_;
|
||||||
int32 num_running_ GUARDED_BY(mu_) = 0;
|
int32 num_running_ GUARDED_BY(mu_) = 0;
|
||||||
@ -131,14 +168,18 @@ class MasterSession {
|
|||||||
// Private dtor. The client must call Close().
|
// Private dtor. The client must call Close().
|
||||||
virtual ~MasterSession();
|
virtual ~MasterSession();
|
||||||
|
|
||||||
Status StartStep(const RunStepRequest& req, BuildGraphOptions* opts,
|
Status StartStep(const BuildGraphOptions& opts, int64* count,
|
||||||
int64* count, ReffedClientGraph** graph);
|
ReffedClientGraph** graph, bool is_partial);
|
||||||
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
||||||
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequest* req,
|
Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequest* req,
|
||||||
RunStepResponse* resp);
|
RunStepResponse* resp);
|
||||||
|
Status DoPartialRun(CallOptions* opts, const RunStepRequest* req,
|
||||||
|
RunStepResponse* resp);
|
||||||
void UpdateLastAccessTime();
|
void UpdateLastAccessTime();
|
||||||
|
|
||||||
|
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
|
TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ class OpKernel {
|
|||||||
// We allow legacy scalars within Google up until GraphDef version 6.
|
// We allow legacy scalars within Google up until GraphDef version 6.
|
||||||
// TODO(irving): Remove when we can drop support for GraphDef version 5.
|
// TODO(irving): Remove when we can drop support for GraphDef version 5.
|
||||||
bool allow_legacy_scalars() const {
|
bool allow_legacy_scalars() const {
|
||||||
#if defined(PLATFORM_GOOGLE)
|
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
|
||||||
return graph_def_version_ < 6;
|
return graph_def_version_ < 6;
|
||||||
#else
|
#else
|
||||||
return false;
|
return false;
|
||||||
|
@ -1136,8 +1136,9 @@ tf_kernel_libraries(
|
|||||||
":eigen_helpers",
|
":eigen_helpers",
|
||||||
":image_resizer_state",
|
":image_resizer_state",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:gif_internal",
|
||||||
"//tensorflow/core:image_ops_op_lib",
|
"//tensorflow/core:image_ops_op_lib",
|
||||||
"//tensorflow/core:jpeg",
|
"//tensorflow/core:jpeg_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -2099,11 +2100,13 @@ tf_kernel_libraries(
|
|||||||
"count_up_to_op",
|
"count_up_to_op",
|
||||||
"dense_update_ops",
|
"dense_update_ops",
|
||||||
"scatter_op",
|
"scatter_op",
|
||||||
|
"scatter_nd_op",
|
||||||
"variable_ops",
|
"variable_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":assign_op",
|
":assign_op",
|
||||||
":bounds_check",
|
":bounds_check",
|
||||||
|
":fill_functor",
|
||||||
":scatter_functor",
|
":scatter_functor",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -2117,6 +2120,7 @@ tf_cc_test(
|
|||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["scatter_op_test.cc"],
|
srcs = ["scatter_op_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":fill_functor",
|
||||||
":ops_testutil",
|
":ops_testutil",
|
||||||
":ops_util",
|
":ops_util",
|
||||||
":scatter_op",
|
":scatter_op",
|
||||||
@ -2129,6 +2133,23 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "scatter_nd_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["scatter_nd_op_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":ops_testutil",
|
||||||
|
":ops_util",
|
||||||
|
":scatter_nd_op",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_libraries(
|
tf_kernel_libraries(
|
||||||
name = "string",
|
name = "string",
|
||||||
prefixes = [
|
prefixes = [
|
||||||
@ -2571,6 +2592,7 @@ filegroup(
|
|||||||
"debug_ops.*",
|
"debug_ops.*",
|
||||||
# Ops excluded because they do not build correctly for Android.
|
# Ops excluded because they do not build correctly for Android.
|
||||||
# See b/29213790
|
# See b/29213790
|
||||||
|
"scatter_nd_op*",
|
||||||
"sparse_matmul_op.*",
|
"sparse_matmul_op.*",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
@ -64,7 +64,7 @@ class TestFileSystem : public NullFileSystem {
|
|||||||
std::unique_ptr<ReadOnlyMemoryRegion>* result) override {
|
std::unique_ptr<ReadOnlyMemoryRegion>* result) override {
|
||||||
float val = 0;
|
float val = 0;
|
||||||
StringPiece scheme, host, path;
|
StringPiece scheme, host, path;
|
||||||
ParseURI(fname, &scheme, &host, &path);
|
io::ParseURI(fname, &scheme, &host, &path);
|
||||||
// For the tests create in-memory regions with float values equal to the
|
// For the tests create in-memory regions with float values equal to the
|
||||||
// region name.
|
// region name.
|
||||||
if (path == "/2") {
|
if (path == "/2") {
|
||||||
|
@ -46,25 +46,6 @@ namespace functor {
|
|||||||
using random::PhiloxRandom;
|
using random::PhiloxRandom;
|
||||||
using random::SingleSampleAdapter;
|
using random::SingleSampleAdapter;
|
||||||
|
|
||||||
// Sample a truncated normal random variable, with mean, stddev, minval, and
|
|
||||||
// maxval parameters for each batch. Uses two rejection sampling algorithms
|
|
||||||
// described in http://rd.springer.com/article/10.1007/BF00143942.
|
|
||||||
//
|
|
||||||
// Either minval may be -infinity, or maxval may be +infinity. If the interval
|
|
||||||
// (minval, maxval) is empty, the result is NaN. Large intervals which include
|
|
||||||
// both tails may have reduced accuracy.
|
|
||||||
template <typename Device, typename T>
|
|
||||||
struct TruncatedNormalFunctor {
|
|
||||||
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
|
|
||||||
int64 samples_per_batch, int64 num_elements,
|
|
||||||
typename TTypes<T>::ConstFlat means,
|
|
||||||
typename TTypes<T>::ConstFlat stddevs,
|
|
||||||
typename TTypes<T>::ConstFlat minvals,
|
|
||||||
typename TTypes<T>::ConstFlat maxvals,
|
|
||||||
const random::PhiloxRandom& gen,
|
|
||||||
typename TTypes<T>::Flat output);
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct TruncatedNormalFunctor<CPUDevice, T> {
|
struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||||
static const int kMaxIterations = 100;
|
static const int kMaxIterations = 100;
|
||||||
@ -96,8 +77,8 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
|||||||
|
|
||||||
// Vectorized intermediate calculations for uniform rejection sampling.
|
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||||
// We always generate at most 4 samples.
|
// We always generate at most 4 samples.
|
||||||
tensorflow::random::Array<T, 4> z;
|
Eigen::array<T, 4> z;
|
||||||
tensorflow::random::Array<T, 4> g;
|
Eigen::array<T, 4> g;
|
||||||
|
|
||||||
for (int64 b = start_batch; b < limit_batch; ++b) {
|
for (int64 b = start_batch; b < limit_batch; ++b) {
|
||||||
// We are passed a flat array for each of the parameter tensors.
|
// We are passed a flat array for each of the parameter tensors.
|
||||||
@ -145,13 +126,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
|||||||
if (diff < cutoff) {
|
if (diff < cutoff) {
|
||||||
// Sample from a uniform distribution on [normMin, normMax].
|
// Sample from a uniform distribution on [normMin, normMax].
|
||||||
|
|
||||||
T plusFactor;
|
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
|
||||||
if (normMin < T(0)) {
|
|
||||||
// normMax > 0 because it is flipped otherwise.
|
|
||||||
plusFactor = T(0);
|
|
||||||
} else {
|
|
||||||
plusFactor = normMin * normMin;
|
|
||||||
}
|
|
||||||
|
|
||||||
while (sample < limit_sample) {
|
while (sample < limit_sample) {
|
||||||
const auto rand = dist(&gen_copy);
|
const auto rand = dist(&gen_copy);
|
||||||
@ -395,4 +370,21 @@ TF_CALL_double(REGISTER);
|
|||||||
|
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#define REGISTER(TYPE) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
|
||||||
|
.Device(DEVICE_GPU) \
|
||||||
|
.HostMemory("shape") \
|
||||||
|
.TypeConstraint<TYPE>("dtype"), \
|
||||||
|
ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)
|
||||||
|
|
||||||
|
TF_CALL_half(REGISTER);
|
||||||
|
TF_CALL_float(REGISTER);
|
||||||
|
TF_CALL_double(REGISTER);
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -16,14 +16,35 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
||||||
#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class OpKernelContext;
|
class OpKernelContext;
|
||||||
|
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
|
// Sample a truncated normal random variable, with mean, stddev, minval, and
|
||||||
|
// maxval parameters for each batch. Uses two rejection sampling algorithms
|
||||||
|
// described in http://rd.springer.com/article/10.1007/BF00143942.
|
||||||
|
//
|
||||||
|
// Either minval may be -infinity, or maxval may be +infinity. If the interval
|
||||||
|
// (minval, maxval) is empty, the result is NaN. Large intervals which include
|
||||||
|
// both tails may have reduced accuracy.
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
struct TruncatedNormalFunctor;
|
struct TruncatedNormalFunctor {
|
||||||
|
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
|
||||||
|
int64 samples_per_batch, int64 num_elements,
|
||||||
|
typename TTypes<T>::ConstFlat means,
|
||||||
|
typename TTypes<T>::ConstFlat stddevs,
|
||||||
|
typename TTypes<T>::ConstFlat minvals,
|
||||||
|
typename TTypes<T>::ConstFlat maxvals,
|
||||||
|
const random::PhiloxRandom& gen,
|
||||||
|
typename TTypes<T>::Flat output);
|
||||||
|
|
||||||
|
static const int kMaxIterations = 100;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -0,0 +1,214 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/random/philox_random.h"
|
||||||
|
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||||
|
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||||
|
|
||||||
|
#define UNROLL _Pragma("unroll")
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class OpKernelContext;
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void __launch_bounds__(1024)
|
||||||
|
TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
|
||||||
|
int64 samples_per_batch, int64 num_elements,
|
||||||
|
const T* means, bool single_mean, const T* stddevs,
|
||||||
|
bool single_stddev, const T* minvals,
|
||||||
|
bool single_minval, const T* maxvals,
|
||||||
|
bool single_maxval, int64 kMaxIterations) {
|
||||||
|
const int32 max_samples_per_item = 2 * kMaxIterations;
|
||||||
|
// Initial offset as given by CUDA_1D_KERNEL_LOOP.
|
||||||
|
const int32 initial_offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
gen.Skip(max_samples_per_item * initial_offset);
|
||||||
|
typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
|
||||||
|
Uniform dist;
|
||||||
|
const int kDistSize = Uniform::kResultElementCount;
|
||||||
|
const T quietNaN = Eigen::NumTraits<T>::quiet_NaN();
|
||||||
|
|
||||||
|
// We skip the total number of threads to get to the next element. To produce
|
||||||
|
// deterministic results between devices, each element in the output array
|
||||||
|
// skips max_samples_per_item in the generator. Then after generating this
|
||||||
|
// item, we need to skip the samples for one element for every thread to get
|
||||||
|
// to the next element that we actually process.
|
||||||
|
const int32 samples_between_processed_elements =
|
||||||
|
max_samples_per_item * (gridDim.x * blockDim.x);
|
||||||
|
|
||||||
|
CUDA_1D_KERNEL_LOOP(offset, num_elements) {
|
||||||
|
// Track how many more samples we need to skip before we process the next
|
||||||
|
// element.
|
||||||
|
int32 remaining_samples = samples_between_processed_elements;
|
||||||
|
|
||||||
|
const int64 batch_id = offset / samples_per_batch;
|
||||||
|
T mean = means[single_mean ? 0 : batch_id];
|
||||||
|
const T input_stddev = stddevs[single_stddev ? 0 : batch_id];
|
||||||
|
T minval = minvals[single_minval ? 0 : batch_id];
|
||||||
|
T maxval = maxvals[single_maxval ? 0 : batch_id];
|
||||||
|
|
||||||
|
// Flip the distribution if we can make the lower bound positive.
|
||||||
|
T stddev;
|
||||||
|
if (Eigen::numext::isinf(minval) || maxval < mean) {
|
||||||
|
// Reverse all calculations. normMin and normMax will be flipped.
|
||||||
|
// std::swap is a host function (not available in CUDA).
|
||||||
|
T temp = minval;
|
||||||
|
minval = maxval;
|
||||||
|
maxval = temp;
|
||||||
|
stddev = -input_stddev;
|
||||||
|
} else {
|
||||||
|
stddev = input_stddev;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate normalized samples, then scale them.
|
||||||
|
const T normMin = (minval - mean) / stddev;
|
||||||
|
const T normMax = (maxval - mean) / stddev;
|
||||||
|
|
||||||
|
// Determine the method to use.
|
||||||
|
const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
|
||||||
|
const T cutoff =
|
||||||
|
T(2) *
|
||||||
|
Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) /
|
||||||
|
(normMin + sqrtFactor);
|
||||||
|
const T diff = normMax - normMin;
|
||||||
|
|
||||||
|
// Validate the normalized min and max, because the originals may have been
|
||||||
|
// flipped already.
|
||||||
|
if (!(input_stddev > T(0) && normMin < normMax &&
|
||||||
|
(Eigen::numext::isfinite(normMin) ||
|
||||||
|
Eigen::numext::isfinite(normMax)))) {
|
||||||
|
data[offset] = quietNaN;
|
||||||
|
} else if (diff < cutoff) {
|
||||||
|
// Sample from a uniform distribution on [normMin, normMax].
|
||||||
|
|
||||||
|
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||||
|
// We always generate at most 4 samples.
|
||||||
|
Eigen::array<T, 4> z;
|
||||||
|
Eigen::array<T, 4> g;
|
||||||
|
|
||||||
|
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
|
||||||
|
|
||||||
|
int numIterations = 0;
|
||||||
|
while (numIterations < kMaxIterations) {
|
||||||
|
const auto rand = dist(&gen);
|
||||||
|
remaining_samples -= gen.kResultElementCount;
|
||||||
|
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||||
|
z[i] = rand[i] * diff + normMin;
|
||||||
|
}
|
||||||
|
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||||
|
g[i] = (plusFactor - z[i] * z[i]) / 2.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto u = dist(&gen);
|
||||||
|
remaining_samples -= gen.kResultElementCount;
|
||||||
|
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||||
|
if (u[i] <= Eigen::numext::exp(g[i]) ||
|
||||||
|
numIterations + 1 >= kMaxIterations) {
|
||||||
|
// Accept the sample z.
|
||||||
|
// If we run out of iterations, just use the current uniform
|
||||||
|
// sample. Empirically, the probability of accepting each sample
|
||||||
|
// is at least 50% for typical inputs, so we will always accept
|
||||||
|
// by 100 iterations.
|
||||||
|
// This introduces a slight inaccuracy when at least one bound
|
||||||
|
// is large, minval is negative and maxval is positive.
|
||||||
|
data[offset] = z[i] * stddev + mean;
|
||||||
|
// Break out of the nested loop by updating numIterations.
|
||||||
|
numIterations = kMaxIterations;
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
numIterations++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Sample from an exponential distribution with alpha maximizing
|
||||||
|
// acceptance probability, offset by normMin from the origin.
|
||||||
|
// Accept only if less than normMax.
|
||||||
|
const T alpha =
|
||||||
|
(normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) / T(2);
|
||||||
|
int numIterations = 0;
|
||||||
|
while (numIterations < kMaxIterations) {
|
||||||
|
auto rand = dist(&gen);
|
||||||
|
remaining_samples -= gen.kResultElementCount;
|
||||||
|
UNROLL for (int i = 0; i < kDistSize; i += 2) {
|
||||||
|
const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
|
||||||
|
const T x = normMin < alpha ? alpha - z : normMin - alpha;
|
||||||
|
const T g = Eigen::numext::exp(-x * x / 2.0);
|
||||||
|
const T u = rand[i + 1];
|
||||||
|
if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) {
|
||||||
|
data[offset] = z * stddev + mean;
|
||||||
|
// Break out of the nested loop by updating numIterations.
|
||||||
|
numIterations = kMaxIterations;
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
numIterations++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gen.Skip(remaining_samples);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Partial specialization for GPU
|
||||||
|
template <typename T>
|
||||||
|
struct TruncatedNormalFunctor<GPUDevice, T> {
|
||||||
|
static const int kMaxIterations = 100;
|
||||||
|
|
||||||
|
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
|
||||||
|
int64 samples_per_batch, int64 num_elements,
|
||||||
|
typename TTypes<T>::ConstFlat means,
|
||||||
|
typename TTypes<T>::ConstFlat stddevs,
|
||||||
|
typename TTypes<T>::ConstFlat minvals,
|
||||||
|
typename TTypes<T>::ConstFlat maxvals,
|
||||||
|
const random::PhiloxRandom& gen,
|
||||||
|
typename TTypes<T>::Flat output) {
|
||||||
|
const auto config = GetCudaLaunchConfig(num_elements, d);
|
||||||
|
|
||||||
|
TruncatedNormalKernel<
|
||||||
|
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||||
|
gen, output.data(), num_batches, samples_per_batch, num_elements,
|
||||||
|
means.data(), means.dimension(0) == 1, stddevs.data(),
|
||||||
|
stddevs.dimension(0) == 1, minvals.data(), minvals.dimension(0) == 1,
|
||||||
|
maxvals.data(), maxvals.dimension(0) == 1, kMaxIterations);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Explicit instantiation of the GPU distributions functors
|
||||||
|
template struct TruncatedNormalFunctor<GPUDevice, Eigen::half>;
|
||||||
|
template struct TruncatedNormalFunctor<GPUDevice, float>;
|
||||||
|
template struct TruncatedNormalFunctor<GPUDevice, double>;
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
@ -131,5 +131,8 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
|
|||||||
BM_PTruncatedNormalDev(cpu, 1000, 1000);
|
BM_PTruncatedNormalDev(cpu, 1000, 1000);
|
||||||
BM_PTruncatedNormalDev_2SD(cpu, 10000, 100);
|
BM_PTruncatedNormalDev_2SD(cpu, 10000, 100);
|
||||||
BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100);
|
BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100);
|
||||||
|
BM_PTruncatedNormalDev(gpu, 1000, 1000);
|
||||||
|
BM_PTruncatedNormalDev_2SD(gpu, 10000, 100);
|
||||||
|
BM_PTruncatedNormalDev_OneTail(gpu, 10000, 100);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
402
tensorflow/core/kernels/scatter_nd_op.cc
Normal file
402
tensorflow/core/kernels/scatter_nd_op.cc
Normal file
@ -0,0 +1,402 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// See docs in ../ops/state_ops.cc.
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
// Check whether updates.shape = indices.shape[0] + params.shape[IXDIM:]
|
||||||
|
static bool ValidUpdateShape(const TensorShape& params_shape,
|
||||||
|
const Tensor& indices, const Tensor& updates) {
|
||||||
|
int64 indices_nd = 1;
|
||||||
|
if (indices.dims() > 1) {
|
||||||
|
indices_nd = indices.dim_size(1);
|
||||||
|
}
|
||||||
|
for (int d = indices_nd; d < params_shape.dims(); d++) {
|
||||||
|
if (updates.dim_size(d - indices_nd + 1) != params_shape.dim_size(d)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Index>
|
||||||
|
static void PrepareAndValidateInputs(OpKernelContext* c,
|
||||||
|
const TensorShape& params_shape,
|
||||||
|
const Tensor& indices,
|
||||||
|
const Tensor& updates, int64* indices_nd,
|
||||||
|
Index* num_updates, Index* slice_size) {
|
||||||
|
const TensorShape& indices_shape(indices.shape());
|
||||||
|
const TensorShape& updates_shape(updates.shape());
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
c, TensorShapeUtils::IsVectorOrHigher(params_shape),
|
||||||
|
errors::InvalidArgument("Output must be at least 1-D, ", "got shape: ",
|
||||||
|
params_shape.DebugString()));
|
||||||
|
|
||||||
|
OP_REQUIRES(c, params_shape.num_elements() >= 0 ||
|
||||||
|
(indices.NumElements() == 0 && updates.NumElements() == 0),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Indices and updates specified for empty output", " shape"));
|
||||||
|
|
||||||
|
OP_REQUIRES(c, updates.dim_size(0) == indices.dim_size(0),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"The outermost dimension of updates and indices ",
|
||||||
|
"must match. Got indices.shape ", indices_shape.DebugString(),
|
||||||
|
", updates.shape ", updates_shape.DebugString()));
|
||||||
|
OP_REQUIRES(
|
||||||
|
c, ValidUpdateShape(params_shape, indices, updates),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Must have updates.shape = indices.shape[0] + params_shape[IXDIM:], ",
|
||||||
|
"got updates.shape ", updates_shape.DebugString(), ", indices.shape ",
|
||||||
|
indices_shape.DebugString(), ", params_shape ",
|
||||||
|
params_shape.DebugString()));
|
||||||
|
// Check that we have enough index space
|
||||||
|
const int64 N_big = indices.NumElements();
|
||||||
|
OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"indices has too many elements for ",
|
||||||
|
DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
|
||||||
|
N_big, " > ", std::numeric_limits<Index>::max()));
|
||||||
|
OP_REQUIRES(
|
||||||
|
c, params_shape.dim_size(0) <= std::numeric_limits<Index>::max(),
|
||||||
|
errors::InvalidArgument("params_shape[0] too large for ",
|
||||||
|
DataTypeString(DataTypeToEnum<Index>::v()),
|
||||||
|
" indexing: ", params_shape.dim_size(0), " > ",
|
||||||
|
std::numeric_limits<Index>::max()));
|
||||||
|
|
||||||
|
// Calculate the number of dimensions in indices
|
||||||
|
*indices_nd = 1;
|
||||||
|
if (indices_shape.dims() > 1) {
|
||||||
|
*indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the number of elements that make up each slice of our updated
|
||||||
|
// tensor. This allows us to work with flattened tensors and copy over whole
|
||||||
|
// slices at a time.
|
||||||
|
Index total_nd = params_shape.dims();
|
||||||
|
|
||||||
|
int64 slice_size_big = 1;
|
||||||
|
for (int64 i = *indices_nd; i < total_nd; ++i) {
|
||||||
|
slice_size_big *= params_shape.dim_size(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
OP_REQUIRES(c, slice_size_big <= std::numeric_limits<Index>::max(),
|
||||||
|
errors::InvalidArgument("slice size is too large for indexing: ",
|
||||||
|
slice_size_big, " > ",
|
||||||
|
std::numeric_limits<Index>::max()));
|
||||||
|
|
||||||
|
*slice_size = static_cast<Index>(slice_size_big);
|
||||||
|
|
||||||
|
const int64 safe_indices_nd = (*indices_nd < 1) ? 1 : *indices_nd;
|
||||||
|
*num_updates = indices_shape.num_elements() / safe_indices_nd;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Device, typename T, typename Index>
|
||||||
|
class ScatterNdOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||||
|
const DataType dt = DataTypeToEnum<T>::v();
|
||||||
|
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||||
|
OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* c) override {
|
||||||
|
const Tensor& indices = c->input(0);
|
||||||
|
const Tensor& updates = c->input(1);
|
||||||
|
const Tensor& shape_input = c->input(2);
|
||||||
|
|
||||||
|
OP_REQUIRES(c, shape_input.dims() == 1,
|
||||||
|
errors::InvalidArgument("Shape must be a vector"));
|
||||||
|
auto vec = shape_input.flat<Index>();
|
||||||
|
TensorShape shape;
|
||||||
|
TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape);
|
||||||
|
|
||||||
|
int64 indices_nd;
|
||||||
|
Index num_updates;
|
||||||
|
Index slice_size;
|
||||||
|
PrepareAndValidateInputs<Index>(c, shape, indices, updates, &indices_nd,
|
||||||
|
&num_updates, &slice_size);
|
||||||
|
if (!c->status().ok()) return;
|
||||||
|
|
||||||
|
Tensor scratch;
|
||||||
|
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
|
||||||
|
|
||||||
|
auto scratch_scalar = scratch.scalar<Index>();
|
||||||
|
auto indices_flat = indices.flat_inner_dims<Index>();
|
||||||
|
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
|
||||||
|
|
||||||
|
Index bad_i = -1;
|
||||||
|
switch (indices_nd) {
|
||||||
|
#define PARAMS_CASE(IXDIM) \
|
||||||
|
case IXDIM: { \
|
||||||
|
Tensor* out = nullptr; \
|
||||||
|
OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); \
|
||||||
|
functor::SetZeroFunctor<Device, T> fill; \
|
||||||
|
fill(c->eigen_device<Device>(), out->flat<T>()); \
|
||||||
|
if (shape.num_elements() > 0) { \
|
||||||
|
auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>(); \
|
||||||
|
functor::ScatterNdFunctor<Device, T, Index, \
|
||||||
|
scatter_nd_op::UpdateOp::ASSIGN, (IXDIM)> \
|
||||||
|
functor; \
|
||||||
|
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
|
||||||
|
output_flat, indices_flat, updates_flat, output_flat); \
|
||||||
|
} \
|
||||||
|
} break
|
||||||
|
PARAMS_CASE(0);
|
||||||
|
PARAMS_CASE(1);
|
||||||
|
PARAMS_CASE(2);
|
||||||
|
PARAMS_CASE(3);
|
||||||
|
PARAMS_CASE(4);
|
||||||
|
PARAMS_CASE(5);
|
||||||
|
#undef PARAMS_CASE
|
||||||
|
default:
|
||||||
|
OP_REQUIRES(c, false,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Only indices.shape[-1] values between 0 and 5 "
|
||||||
|
"are currently supported. Requested rank: ",
|
||||||
|
indices_nd));
|
||||||
|
}
|
||||||
|
OP_REQUIRES(
|
||||||
|
c, bad_i < 0,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Invalid indices: ", SliceDebugString(indices.shape(), bad_i),
|
||||||
|
" = [", str_util::Join(gtl::ArraySlice<Index>(
|
||||||
|
&indices_flat(bad_i, 0), indices_nd),
|
||||||
|
", "),
|
||||||
|
"] does not index into ", shape.DebugString()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T, typename Index,
|
||||||
|
scatter_nd_op::UpdateOp op>
|
||||||
|
class ScatterNdUpdateOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||||
|
const DataType dt = DataTypeToEnum<T>::v();
|
||||||
|
const DataType dt_ref = DataTypeToEnum<T>::ref();
|
||||||
|
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||||
|
OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* c) override {
|
||||||
|
if (use_exclusive_lock_) {
|
||||||
|
// Hold mutex while we apply updates
|
||||||
|
mutex_lock l(*c->input_ref_mutex(0));
|
||||||
|
DoCompute(c);
|
||||||
|
} else {
|
||||||
|
DoCompute(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool use_exclusive_lock_;
|
||||||
|
|
||||||
|
void DoCompute(OpKernelContext* c) {
|
||||||
|
Tensor params = c->mutable_input(0, use_exclusive_lock_);
|
||||||
|
const Tensor& indices = c->input(1);
|
||||||
|
const Tensor& updates = c->input(2);
|
||||||
|
const TensorShape& params_shape(params.shape());
|
||||||
|
|
||||||
|
int64 indices_nd;
|
||||||
|
Index num_updates;
|
||||||
|
Index slice_size;
|
||||||
|
|
||||||
|
OP_REQUIRES(c, params.IsInitialized(),
|
||||||
|
errors::FailedPrecondition("Null ref for params"));
|
||||||
|
PrepareAndValidateInputs<Index>(c, params_shape, indices, updates,
|
||||||
|
&indices_nd, &num_updates, &slice_size);
|
||||||
|
if (!c->status().ok()) return;
|
||||||
|
|
||||||
|
Tensor scratch;
|
||||||
|
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
|
||||||
|
|
||||||
|
auto scratch_scalar = scratch.scalar<Index>();
|
||||||
|
auto indices_flat = indices.flat_inner_dims<Index>();
|
||||||
|
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
|
||||||
|
|
||||||
|
Index bad_i = -1;
|
||||||
|
c->forward_ref_input_to_ref_output(0, 0);
|
||||||
|
switch (indices_nd) {
|
||||||
|
#define PARAMS_CASE(IXDIM) \
|
||||||
|
case IXDIM: { \
|
||||||
|
auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>(); \
|
||||||
|
functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \
|
||||||
|
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
|
||||||
|
params_flat, indices_flat, updates_flat, params_flat); \
|
||||||
|
} break
|
||||||
|
PARAMS_CASE(0);
|
||||||
|
PARAMS_CASE(1);
|
||||||
|
PARAMS_CASE(2);
|
||||||
|
PARAMS_CASE(3);
|
||||||
|
PARAMS_CASE(4);
|
||||||
|
PARAMS_CASE(5);
|
||||||
|
#undef PARAMS_CASE
|
||||||
|
default:
|
||||||
|
OP_REQUIRES(c, false,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Only indices.shape[-1] values between 1 and 5 "
|
||||||
|
"are currently supported. Requested rank: ",
|
||||||
|
indices_nd));
|
||||||
|
}
|
||||||
|
OP_REQUIRES(
|
||||||
|
c, bad_i < 0,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Invalid indices: ", SliceDebugString(indices.shape(), bad_i),
|
||||||
|
" = [", str_util::Join(gtl::ArraySlice<Index>(
|
||||||
|
&indices_flat(bad_i, 0), indices_nd),
|
||||||
|
", "),
|
||||||
|
"] is not in [0, ", params.dim_size(0), ")"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name(name) \
|
||||||
|
.Device(DEVICE_##dev) \
|
||||||
|
.TypeConstraint<type>("T") \
|
||||||
|
.TypeConstraint<index_type>("Tindices"), \
|
||||||
|
ScatterNdOp<dev##Device, type, index_type>)
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
|
||||||
|
op) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name(name) \
|
||||||
|
.Device(DEVICE_##dev) \
|
||||||
|
.TypeConstraint<type>("T") \
|
||||||
|
.TypeConstraint<index_type>("Tindices"), \
|
||||||
|
ScatterNdUpdateOp<dev##Device, type, index_type, op>)
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \
|
||||||
|
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
|
||||||
|
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
|
||||||
|
scatter_nd_op::UpdateOp::ADD); \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
|
||||||
|
scatter_nd_op::UpdateOp::SUB); \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
|
||||||
|
scatter_nd_op::UpdateOp::MUL); \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
|
||||||
|
scatter_nd_op::UpdateOp::DIV);
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND(type, dev) \
|
||||||
|
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
|
||||||
|
scatter_nd_op::UpdateOp::ASSIGN);
|
||||||
|
|
||||||
|
// Registers CPU kernels.
|
||||||
|
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
|
||||||
|
REGISTER_SCATTER_ND_ADD_SUB(type, CPU);
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE(type, CPU);
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
|
||||||
|
|
||||||
|
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
|
||||||
|
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
|
||||||
|
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
|
||||||
|
|
||||||
|
// Registers GPU kernels.
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
|
||||||
|
REGISTER_SCATTER_ND_ADD_SUB(type, GPU);
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
|
||||||
|
REGISTER_SCATTER_ND_UPDATE(type, GPU);
|
||||||
|
|
||||||
|
// TODO(simister): Re-enable when GPU support is working.
|
||||||
|
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_GPU);
|
||||||
|
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_GPU);
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
#undef REGISTER_SCATTER_ND_ADD
|
||||||
|
#undef REGISTER_SCATTER_ND_ADD_SUB
|
||||||
|
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
|
||||||
|
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
|
||||||
|
#undef REGISTER_SCATTER_ND_UPDATE
|
||||||
|
#undef REGISTER_SCATTER_ND_UPDATE_CPU
|
||||||
|
#undef REGISTER_SCATTER_ND_UPDATE_GPU
|
||||||
|
#undef REGISTER_SCATTER_ND_KERNEL
|
||||||
|
#undef REGISTER_SCATTER_ND_KERNEL_INDEX
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {

// Declares the GPU specialization of ScatterNdFunctor::operator() for one
// (T, Index, op, IXDIM) combination. The parameter list mirrors the primary
// template's operator() declared in scatter_nd_op.h.
#define DECLARE_GPU_SPECS_OP(T, Index, op, IXDIM)                        \
  template <>                                                            \
  Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()(    \
      const GPUDevice& d, const Index slice_size,                        \
      typename TTypes<Index>::Scalar Tscratch,                           \
      typename TTypes<T, (IXDIM) + 1>::Tensor Tparams,                   \
      typename TTypes<Index, 2>::ConstTensor Tindices,                   \
      typename TTypes<T, 2>::ConstTensor Tupdates,                       \
      typename TTypes<T, (IXDIM) + 1>::Tensor Toutput);                  \
  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

// All index depths handled by the kernels' PARAMS_CASE switch (0 through 5).
#define DECLARE_GPU_SPECS_OPS(T, Index, op) \
  DECLARE_GPU_SPECS_OP(T, Index, op, 0);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 1);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 2);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 3);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 4);    \
  DECLARE_GPU_SPECS_OP(T, Index, op, 5)

#define DECLARE_GPU_SPECS_INDEX(T, Index)                           \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD);    \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB);    \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL);    \
  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV);

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64);

// TODO(simister): Re-enable when GPU support is working.
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_OPS
#undef DECLARE_GPU_SPECS_OP

}  // namespace functor
#endif  // GOOGLE_CUDA
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
62
tensorflow/core/kernels/scatter_nd_op.h
Normal file
62
tensorflow/core/kernels/scatter_nd_op.h
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
|
||||||
|
#define TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
|
||||||
|
class OpKernelContext;
|
||||||
|
|
||||||
|
namespace scatter_nd_op {
|
||||||
|
|
||||||
|
enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
|
||||||
|
|
||||||
|
} // namespace scatter_nd_op
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
// Functor used by the ScatterNd ops to do the computations.
|
||||||
|
template <typename Device, typename T, typename Index,
|
||||||
|
scatter_nd_op::UpdateOp op, int IXDIM>
|
||||||
|
struct ScatterNdFunctor {
|
||||||
|
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
|
||||||
|
Index operator()(const Device& d, const Index slice_size,
|
||||||
|
typename TTypes<Index>::Scalar Tscratch,
|
||||||
|
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||||
|
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||||
|
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||||
|
typename TTypes<T, IXDIM + 1>::Tensor Toutput);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
|
224
tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
Normal file
224
tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
|
||||||
|
|
||||||
|
// Functor definitions for ScatterND ops, must be compilable by nvcc.
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
|
||||||
|
class OpKernelContext;
|
||||||
|
|
||||||
|
// Specialization of UpdateExecutor to CPU
|
||||||
|
namespace generator {
|
||||||
|
|
||||||
|
template <typename T, typename Index, scatter_nd_op::UpdateOp op>
|
||||||
|
class UpdateExecutor {
|
||||||
|
public:
|
||||||
|
static void Update(T* input, const T* updates, T* output, Index slice_size);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Index>
|
||||||
|
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ASSIGN> {
|
||||||
|
public:
|
||||||
|
static void Update(T* /* unused */, const T* updates, T* output,
|
||||||
|
Index slice_size) {
|
||||||
|
std::copy_n(updates, slice_size, output);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Index>
|
||||||
|
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ADD> {
|
||||||
|
public:
|
||||||
|
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||||
|
std::transform(input, input + slice_size, updates, output, std::plus<T>());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Index>
|
||||||
|
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::SUB> {
|
||||||
|
public:
|
||||||
|
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||||
|
std::transform(input, input + slice_size, updates, output, std::minus<T>());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Index>
|
||||||
|
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::MUL> {
|
||||||
|
public:
|
||||||
|
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||||
|
std::transform(input, input + slice_size, updates, output,
|
||||||
|
std::multiplies<T>());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Index>
|
||||||
|
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::DIV> {
|
||||||
|
public:
|
||||||
|
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||||
|
std::transform(input, input + slice_size, updates, output,
|
||||||
|
std::divides<T>());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
|
||||||
|
class ScatterNdSliceGenerator {
|
||||||
|
public:
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ScatterNdSliceGenerator(
|
||||||
|
const Index slice_size, typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||||
|
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||||
|
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||||
|
typename TTypes<T, IXDIM + 1>::Tensor Toutput,
|
||||||
|
std::atomic<Index>* error_loc)
|
||||||
|
: slice_size_(slice_size),
|
||||||
|
Tparams_(Tparams),
|
||||||
|
Tindices_(Tindices),
|
||||||
|
Tupdates_(Tupdates),
|
||||||
|
Toutput_(Toutput),
|
||||||
|
error_loc_(error_loc) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC bool GenerateIndices(
|
||||||
|
const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
|
||||||
|
(*ix)[IXDIM] = 0;
|
||||||
|
bool out_of_bounds = false;
|
||||||
|
for (int i = 0; i < IXDIM; ++i) {
|
||||||
|
const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
|
||||||
|
(*ix)[i] = ix_i;
|
||||||
|
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
|
||||||
|
}
|
||||||
|
return out_of_bounds;
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
|
||||||
|
operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
|
||||||
|
auto loc = loc_array[0];
|
||||||
|
Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix_params;
|
||||||
|
Eigen::array<Eigen::DenseIndex, 2> ix_updates;
|
||||||
|
ix_updates[0] = loc;
|
||||||
|
ix_updates[1] = 0;
|
||||||
|
const bool out_of_bounds = GenerateIndices(loc, &ix_params);
|
||||||
|
if (TF_PREDICT_FALSE(out_of_bounds)) {
|
||||||
|
error_loc_->store(loc);
|
||||||
|
} else {
|
||||||
|
UpdateExecutor<T, Index, op>::Update(&Tparams_(ix_params),
|
||||||
|
&Tupdates_(ix_updates),
|
||||||
|
&Toutput_(ix_params), slice_size_);
|
||||||
|
}
|
||||||
|
return static_cast<int32>(0); // Return something...
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const Index slice_size_;
|
||||||
|
mutable typename TTypes<T, IXDIM + 1>::Tensor Tparams_;
|
||||||
|
const typename TTypes<Index, 2>::ConstTensor Tindices_;
|
||||||
|
const typename TTypes<T, 2>::ConstTensor Tupdates_;
|
||||||
|
mutable typename TTypes<T, IXDIM + 1>::Tensor Toutput_;
|
||||||
|
std::atomic<Index>* error_loc_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace generator
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
// Implementation of update functor for CPU.
|
||||||
|
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
|
||||||
|
struct ScatterNdFunctor<CPUDevice, T, Index, op, IXDIM> {
|
||||||
|
Index operator()(const CPUDevice& d, const Index slice_size,
|
||||||
|
typename TTypes<Index>::Scalar Tscratch,
|
||||||
|
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||||
|
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||||
|
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||||
|
typename TTypes<T, IXDIM + 1>::Tensor Toutput) {
|
||||||
|
std::atomic<Index> error_loc(-1);
|
||||||
|
|
||||||
|
const Eigen::DenseIndex batch_size = Tindices.dimension(0);
|
||||||
|
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||||
|
Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
|
||||||
|
Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
|
||||||
|
#else
|
||||||
|
Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
|
||||||
|
Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
|
||||||
|
broadcast_dims.set(0, batch_size);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
generator::ScatterNdSliceGenerator<T, Index, op, IXDIM> generator(
|
||||||
|
slice_size, Tparams, Tindices, Tupdates, Toutput, &error_loc);
|
||||||
|
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
|
||||||
|
.broadcast(broadcast_dims)
|
||||||
|
.generate(generator)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
// error_loc() returns -1 if there's no out-of-bounds index,
|
||||||
|
// otherwise it returns the location of an OOB index in Tindices.
|
||||||
|
return error_loc.load();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
|
||||||
|
template Index \
|
||||||
|
ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
|
||||||
|
const CPUDevice& d, const Index slice_size, \
|
||||||
|
typename TTypes<Index>::Scalar Tscratch, \
|
||||||
|
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::Tensor Tparams, \
|
||||||
|
typename TTypes<Index, 2>::ConstTensor Tindices, \
|
||||||
|
typename TTypes<T, 2>::ConstTensor Tupdates, \
|
||||||
|
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::Tensor Toutput)
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_INDEX(type, op) \
|
||||||
|
REGISTER_SCATTER_ND_FULL(type, int32, op); \
|
||||||
|
REGISTER_SCATTER_ND_FULL(type, int64, op)
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_UPDATE(type) \
|
||||||
|
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);
|
||||||
|
|
||||||
|
#define REGISTER_SCATTER_ND_MATH(type) \
|
||||||
|
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
|
||||||
|
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); \
|
||||||
|
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MUL); \
|
||||||
|
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::DIV);
|
||||||
|
|
||||||
|
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
|
||||||
|
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH)
|
||||||
|
|
||||||
|
#undef REGISTER_SCATTER_ND_MATH
|
||||||
|
#undef REGISTER_SCATTER_ND_UPDATE
|
||||||
|
#undef REGISTER_SCATTER_ND_INDEX
|
||||||
|
#undef REGISTER_SCATTER_ND_FULL
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define CPU_PROVIDED_IXDIM 0
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||||
|
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define CPU_PROVIDED_IXDIM 1
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||||
|
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define CPU_PROVIDED_IXDIM 2
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||||
|
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define CPU_PROVIDED_IXDIM 3
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||||
|
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
Normal file
18
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define CPU_PROVIDED_IXDIM 4
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||||
|
#undef CPU_PROVIDED_IXDIM
|
19
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
Normal file
19
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
|
||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define CPU_PROVIDED_IXDIM 5
|
||||||
|
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||||
|
#undef CPU_PROVIDED_IXDIM
|
320
tensorflow/core/kernels/scatter_nd_op_test.cc
Normal file
320
tensorflow/core/kernels/scatter_nd_op_test.cc
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class ScatterNdUpdateOpTest : public OpsTestBase {
|
||||||
|
protected:
|
||||||
|
void MakeOp(DataType variable_ref_type, DataType index_type) {
|
||||||
|
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
|
||||||
|
.Input(FakeInput(variable_ref_type))
|
||||||
|
.Input(FakeInput(index_type))
|
||||||
|
.Input(FakeInput(RemoveRefType(variable_ref_type)))
|
||||||
|
.Finalize(node_def()));
|
||||||
|
TF_ASSERT_OK(InitOp());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Overwrites the single element of a string variable.
TEST_F(ScatterNdUpdateOpTest, Simple_StringType) {
  MakeOp(DT_STRING_REF, DT_INT32);

  // Inputs: params, indices, updates.
  const TensorShape one_element({1});
  AddInputFromArray<string>(one_element, {"Brain"});
  AddInputFromArray<int32>(one_element, {0});
  AddInputFromArray<string>(one_element, {"TensorFlow"});
  TF_ASSERT_OK(RunOpKernel());

  // The ref input must now hold the update value.
  Tensor actual = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_STRING, one_element);
  test::FillValues<string>(&want, {"TensorFlow"});
  test::ExpectTensorEqual<string>(want, actual);
}
|
||||||
|
|
||||||
|
// Overwrites the single element of a bool variable.
TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) {
  MakeOp(DT_BOOL_REF, DT_INT32);

  // Inputs: params, indices, updates.
  const TensorShape one_element({1});
  AddInputFromArray<bool>(one_element, {false});
  AddInputFromArray<int32>(one_element, {0});
  AddInputFromArray<bool>(one_element, {true});
  TF_ASSERT_OK(RunOpKernel());

  // The ref input must now hold the update value.
  Tensor actual = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_BOOL, one_element);
  test::FillValues<bool>(&want, {true});
  test::ExpectTensorEqual<bool>(want, actual);
}
|
||||||
|
|
||||||
|
// Replaces rows 0, 4 and 2 of a 5x3 float variable using int32 indices.
TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  // Inputs: zero-filled params, one row index per update, the update rows.
  const TensorShape params_shape({5, 3});
  AddInputFromArray<float>(params_shape,
                           {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
  TF_ASSERT_OK(RunOpKernel());

  // Rows 0, 2 and 4 carry the updates; rows 1 and 3 stay zero.
  Tensor actual = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&want, {100, 101, 102, 0, 0, 0, 10000, 10001,
                                  10002, 0, 0, 0, 777, 778, 779});
  test::ExpectTensorEqual<float>(want, actual);
}
|
||||||
|
|
||||||
|
// Replaces rows 0, 4 and 2 of a 5x3 float variable using int64 indices.
TEST_F(ScatterNdUpdateOpTest, Simple_Two64) {
  MakeOp(DT_FLOAT_REF, DT_INT64);

  // Inputs: zero-filled params, one row index per update, the update rows.
  const TensorShape params_shape({5, 3});
  AddInputFromArray<float>(params_shape,
                           {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  AddInputFromArray<int64>(TensorShape({3, 1}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
  TF_ASSERT_OK(RunOpKernel());

  // Rows 0, 2 and 4 carry the updates; rows 1 and 3 stay zero.
  Tensor actual = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&want, {100, 101, 102, 0, 0, 0, 10000, 10001,
                                  10002, 0, 0, 0, 777, 778, 779});
  test::ExpectTensorEqual<float>(want, actual);
}
|
||||||
|
/*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) {
|
||||||
|
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||||
|
|
||||||
|
// Feed and run
|
||||||
|
AddInputFromArray<float>(TensorShape({0}), {});
|
||||||
|
AddInputFromArray<int32>(TensorShape({0}), {});
|
||||||
|
AddInputFromArray<float>(TensorShape({0}), {});
|
||||||
|
Status s = RunOpKernel();
|
||||||
|
EXPECT_TRUE(StringPiece(s.ToString())
|
||||||
|
.contains("Output must not have 0 elements, got shape: "))
|
||||||
|
<< s;
|
||||||
|
}*/
|
||||||
|
|
||||||
|
// Writes a single scalar update into position 3 of a 5-element vector.
TEST_F(ScatterNdUpdateOpTest, Simple_ZeroD) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  // Inputs: zero-filled params, a single index, a single update value.
  const TensorShape params_shape({5});
  AddInputFromArray<float>(params_shape, {0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({1}), {3});
  AddInputFromArray<float>(TensorShape({1}), {101});
  TF_ASSERT_OK(RunOpKernel());

  // Only element 3 changes.
  Tensor actual = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&want, {0, 0, 0, 101, 0});
  test::ExpectTensorEqual<float>(want, actual);
}
|
||||||
|
|
||||||
|
// Writes three scalar updates into positions 0, 4 and 2 of a 5-element vector.
TEST_F(ScatterNdUpdateOpTest, Simple_OneD) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  // Inputs: zero-filled params, one index per update, the update values.
  const TensorShape params_shape({5});
  AddInputFromArray<float>(params_shape, {0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
  TF_ASSERT_OK(RunOpKernel());

  // Elements 0, 4 and 2 receive 100, 101 and 102 respectively.
  Tensor actual = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&want, {100, 0, 102, 0, 101});
  test::ExpectTensorEqual<float>(want, actual);
}
|
||||||
|
|
||||||
|
// Uses a rank-3 indices tensor ([2, 3, 1]) with rank-2 updates ([2, 3]) to
// write six scalars into an 8-element vector.
TEST_F(ScatterNdUpdateOpTest, HigherRank) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  // Inputs: zero-filled params, indices, updates.
  const TensorShape params_shape({8});
  AddInputFromArray<float>(params_shape, {0, 0, 0, 0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6});
  AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
  TF_ASSERT_OK(RunOpKernel());

  // Positions 5 and 7 are never addressed and stay zero.
  Tensor actual = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&want, {10, 40, 30, 50, 20, 0, 60, 0});
  test::ExpectTensorEqual<float>(want, actual);
}
|
||||||
|
|
||||||
|
// An index of 99 into a 5-row variable must be reported as invalid.
TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  // Inputs: zero-filled params, indices (the last one out of range), updates.
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 99});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});

  // The kernel's status text must name the offending index.
  const Status status = RunOpKernel();
  const string status_text = status.ToString();
  EXPECT_TRUE(StringPiece(status_text)
                  .contains("Invalid indices: [2,0] = [99] is not in [0, 5)"))
      << status;
}
|
||||||
|
|
||||||
|
// Indices of shape [1,3,1] paired with updates of shape [3,3] must be
// rejected because their outermost dimensions differ.
TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  // Inputs: zero-filled params, indices, updates.
  AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({1, 3, 1}), {0, 4, 99});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});

  // The status text must report both offending shapes.
  const Status status = RunOpKernel();
  const string status_text = status.ToString();
  EXPECT_TRUE(StringPiece(status_text)
                  .contains("The outermost dimension of updates and indices "
                            "must match. Got indices.shape [1,3,1], "
                            "updates.shape [3,3]"))
      << status;
}
|
||||||
|
|
||||||
|
// Update rows of width 4 cannot be scattered into params rows of width 3.
TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  // Inputs: zero-filled params, indices, updates with one column too many.
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
  AddInputFromArray<float>(
      TensorShape({3, 4}),
      {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});

  // The status text must describe the required updates shape.
  const Status status = RunOpKernel();
  const string status_text = status.ToString();
  EXPECT_TRUE(StringPiece(status_text)
                  .contains("Must have updates.shape = indices.shape[0] + "
                            "params_shape[IXDIM:], got"))
      << status;
}
|
||||||
|
|
||||||
|
// Three indices paired with only two update rows must be rejected.
TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  // Inputs: zero-filled params, three indices, two update rows.
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {100, 101, 102, 10000, 10001, 10002});

  // The status text must mention the outermost-dimension mismatch.
  const Status status = RunOpKernel();
  const string status_text = status.ToString();
  EXPECT_TRUE(StringPiece(status_text)
                  .contains("The outermost dimension of updates and indices "
                            "must match. Got "))
      << status;
}
|
||||||
|
|
||||||
|
class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
|
||||||
|
public:
|
||||||
|
virtual void TestBody() {}
|
||||||
|
void MakeBenchmarkOp(const char* op, DataType index_type) {
|
||||||
|
TF_ASSERT_OK(NodeDefBuilder("myop", op)
|
||||||
|
.Input(FakeInput(DT_FLOAT_REF))
|
||||||
|
.Input(FakeInput(index_type))
|
||||||
|
.Input(FakeInput(DT_FLOAT))
|
||||||
|
.Finalize(node_def()));
|
||||||
|
TF_CHECK_OK(InitOp());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Index>
|
||||||
|
static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
|
||||||
|
testing::StopTiming();
|
||||||
|
const int kRows = 10000000 / embedding_size;
|
||||||
|
std::vector<float> values;
|
||||||
|
values.reserve(kRows);
|
||||||
|
for (int i = 0; i < kRows * embedding_size; i++) {
|
||||||
|
values.push_back(i);
|
||||||
|
}
|
||||||
|
const int kNumUpdates = 1000;
|
||||||
|
random::PhiloxRandom philox(301, 17);
|
||||||
|
random::SimplePhilox rnd(&philox);
|
||||||
|
std::vector<Index> indices;
|
||||||
|
std::vector<float> updates;
|
||||||
|
for (int i = 0; i < kNumUpdates; i++) {
|
||||||
|
indices.push_back(rnd.Uniform(kRows));
|
||||||
|
for (int j = 0; j < embedding_size; j++) {
|
||||||
|
updates.push_back(i * 10 + j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ScatterNdUpdateBM bm;
|
||||||
|
bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
|
||||||
|
bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
|
||||||
|
bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
|
||||||
|
bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
|
||||||
|
updates);
|
||||||
|
testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
|
||||||
|
iters);
|
||||||
|
testing::StartTiming();
|
||||||
|
while (iters-- > 0) {
|
||||||
|
Status s = bm.RunOpKernel();
|
||||||
|
}
|
||||||
|
testing::StopTiming();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmarks "ScatterNdUpdate" with int32 indices.
static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
  const char* const op_name = "ScatterNdUpdate";
  BM_ScatterNdHelper<int32>(iters, embedding_size, op_name);
}
|
||||||
|
// Benchmarks "ScatterNdUpdate" with int64 indices.
static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
  const char* const op_name = "ScatterNdUpdate";
  BM_ScatterNdHelper<int64>(iters, embedding_size, op_name);
}
|
||||||
|
|
||||||
|
// Benchmarks "ScatterNdAdd" with int32 indices.
static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
  const char* const op_name = "ScatterNdAdd";
  BM_ScatterNdHelper<int32>(iters, embedding_size, op_name);
}
|
||||||
|
// Benchmarks "ScatterNdAdd" with int64 indices.
static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
  const char* const op_name = "ScatterNdAdd";
  BM_ScatterNdHelper<int64>(iters, embedding_size, op_name);
}
|
||||||
|
|
||||||
|
// Every benchmark is run with embedding sizes 1, 10, 64, 256 and 1024.
BENCHMARK(BM_ScatterNdUpdateInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(
    1024);
BENCHMARK(BM_ScatterNdUpdateInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(
    1024);

BENCHMARK(BM_ScatterNdAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterNdAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -88,16 +88,12 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
|
|||||||
|
|
||||||
void ParallelFor(int64 total, int64 cost_per_unit,
|
void ParallelFor(int64 total, int64 cost_per_unit,
|
||||||
std::function<void(int64, int64)> fn) {
|
std::function<void(int64, int64)> fn) {
|
||||||
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
|
|
||||||
CHECK_GE(total, 0);
|
CHECK_GE(total, 0);
|
||||||
CHECK_EQ(total, (int64)(Eigen::Index)total);
|
CHECK_EQ(total, (int64)(Eigen::Index)total);
|
||||||
Eigen::ThreadPoolDevice device(this, this->NumThreads());
|
Eigen::ThreadPoolDevice device(this, this->NumThreads());
|
||||||
device.parallelFor(
|
device.parallelFor(
|
||||||
total, Eigen::TensorOpCost(0, 0, cost_per_unit),
|
total, Eigen::TensorOpCost(0, 0, cost_per_unit),
|
||||||
[&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
|
[&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
|
||||||
#else
|
|
||||||
CHECK(0); // should not be used with the old thread pool
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -57,7 +57,6 @@ TEST(ThreadPool, DoWork) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
|
|
||||||
TEST(ThreadPool, ParallelFor) {
|
TEST(ThreadPool, ParallelFor) {
|
||||||
// Make ParallelFor use as many threads as possible.
|
// Make ParallelFor use as many threads as possible.
|
||||||
int64 kHugeCost = 1 << 30;
|
int64 kHugeCost = 1 << 30;
|
||||||
@ -80,7 +79,6 @@ TEST(ThreadPool, ParallelFor) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
static void BM_Sequential(int iters) {
|
static void BM_Sequential(int iters) {
|
||||||
ThreadPool pool(Env::Default(), "test", kNumThreads);
|
ThreadPool pool(Env::Default(), "test", kNumThreads);
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/lib/strings/scanner.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -49,11 +50,14 @@ string JoinPathImpl(std::initializer_list<StringPiece> paths) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the parts of the path, split on the final "/". If there is no
|
// Return the parts of the URI, split on the final "/" in the path. If there is
|
||||||
// "/" in the path, the first part of the output is empty and the second
|
// no "/" in the path, the first part of the output is the scheme and host, and
|
||||||
// is the input. If the only "/" in the path is the first character, it is
|
// the second is the path. If the only "/" in the path is the first character,
|
||||||
// the first part of the output.
|
// it is included in the first part of the output.
|
||||||
std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) {
|
std::pair<StringPiece, StringPiece> SplitPath(StringPiece uri) {
|
||||||
|
StringPiece scheme, host, path;
|
||||||
|
ParseURI(uri, &scheme, &host, &path);
|
||||||
|
|
||||||
auto pos = path.rfind('/');
|
auto pos = path.rfind('/');
|
||||||
#ifdef PLATFORM_WINDOWS
|
#ifdef PLATFORM_WINDOWS
|
||||||
if (pos == StringPiece::npos)
|
if (pos == StringPiece::npos)
|
||||||
@ -61,15 +65,17 @@ std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) {
|
|||||||
#endif
|
#endif
|
||||||
// Handle the case with no '/' in 'path'.
|
// Handle the case with no '/' in 'path'.
|
||||||
if (pos == StringPiece::npos)
|
if (pos == StringPiece::npos)
|
||||||
return std::make_pair(StringPiece(path.data(), 0), path);
|
return std::make_pair(StringPiece(uri.begin(), host.end() - uri.begin()),
|
||||||
|
path);
|
||||||
|
|
||||||
// Handle the case with a single leading '/' in 'path'.
|
// Handle the case with a single leading '/' in 'path'.
|
||||||
if (pos == 0)
|
if (pos == 0)
|
||||||
return std::make_pair(StringPiece(path.data(), 1),
|
return std::make_pair(
|
||||||
StringPiece(path.data() + 1, path.size() - 1));
|
StringPiece(uri.begin(), path.begin() + 1 - uri.begin()),
|
||||||
|
StringPiece(path.data() + 1, path.size() - 1));
|
||||||
|
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
StringPiece(path.data(), pos),
|
StringPiece(uri.begin(), path.begin() + pos - uri.begin()),
|
||||||
StringPiece(path.data() + pos + 1, path.size() - (pos + 1)));
|
StringPiece(path.data() + pos + 1, path.size() - (pos + 1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,5 +191,42 @@ string CleanPath(StringPiece unclean_path) {
|
|||||||
return path;
|
return path;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Splits a URI of the form scheme://host/path into its three components.
// When no valid scheme prefix is present, the whole input is returned as
// `path` and `scheme`/`host` are empty pieces anchored at its start.
void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
              StringPiece* path) {
  // Step 0: the scheme must match [a-zA-Z][0-9a-zA-Z.]* and be followed by
  // the literal "://".
  // TODO(keveman): Allow "+" and "-" in the scheme.
  strings::Scanner scheme_scanner(remaining);
  const bool has_scheme = scheme_scanner.One(strings::Scanner::LETTER)
                              .Many(strings::Scanner::LETTER_DIGIT_DOT)
                              .StopCapture()
                              .OneLiteral("://")
                              .GetResult(&remaining, scheme);
  if (!has_scheme) {
    // No scheme: treat the entire string as a path.
    *scheme = StringPiece(remaining.begin(), 0);
    *host = StringPiece(remaining.begin(), 0);
    *path = remaining;
    return;
  }

  // Step 1: the host is everything up to the next '/'.
  strings::Scanner host_scanner(remaining);
  const bool has_path = host_scanner.ScanUntil('/').GetResult(&remaining, host);
  if (!has_path) {
    // No '/' follows the scheme, so the remainder is the host and the path
    // is an empty piece at the end of the input.
    *host = remaining;
    *path = StringPiece(remaining.end(), 0);
    return;
  }

  // Step 2: whatever is left is the path.
  *path = remaining;
}
|
||||||
|
|
||||||
|
// Assembles "scheme://host/path". With an empty scheme only the path is
// returned and `host` is ignored.
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) {
  const bool has_scheme = !scheme.empty();
  if (has_scheme) {
    return strings::StrCat(scheme, "://", host, path);
  }
  return path.ToString();
}
|
||||||
|
|
||||||
} // namespace io
|
} // namespace io
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -74,6 +74,21 @@ StringPiece Extension(StringPiece path);
|
|||||||
// string manipulation, completely independent of process state.
|
// string manipulation, completely independent of process state.
|
||||||
string CleanPath(StringPiece path);
|
string CleanPath(StringPiece path);
|
||||||
|
|
||||||
|
// Populates the scheme, host, and path from a URI. scheme, host, and path are
|
||||||
|
// guaranteed by this function to point into the contents of uri, even if
|
||||||
|
// empty.
|
||||||
|
//
|
||||||
|
// Corner cases:
|
||||||
|
// - If the URI is invalid, scheme and host are set to empty strings and the
|
||||||
|
// passed string is assumed to be a path
|
||||||
|
// - If the URI omits the path (e.g. file://host), then the path is left empty.
|
||||||
|
void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host,
|
||||||
|
StringPiece* path);
|
||||||
|
|
||||||
|
// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
|
||||||
|
// return the path.
|
||||||
|
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path);
|
||||||
|
|
||||||
} // namespace io
|
} // namespace io
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -45,6 +45,8 @@ TEST(PathTest, IsAbsolutePath) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(PathTest, Dirname) {
|
TEST(PathTest, Dirname) {
|
||||||
|
EXPECT_EQ("hdfs://127.0.0.1:9000/",
|
||||||
|
Dirname("hdfs://127.0.0.1:9000/train.csv.tfrecords"));
|
||||||
EXPECT_EQ("/hello", Dirname("/hello/"));
|
EXPECT_EQ("/hello", Dirname("/hello/"));
|
||||||
EXPECT_EQ("/", Dirname("/hello"));
|
EXPECT_EQ("/", Dirname("/hello"));
|
||||||
EXPECT_EQ("hello", Dirname("hello/world"));
|
EXPECT_EQ("hello", Dirname("hello/world"));
|
||||||
@ -97,5 +99,47 @@ TEST(PathTest, CleanPath) {
|
|||||||
EXPECT_EQ("../../bar", CleanPath("foo/../../../bar"));
|
EXPECT_EQ("../../bar", CleanPath("foo/../../../bar"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define EXPECT_PARSE_URI(uri, scheme, host, path) \
|
||||||
|
do { \
|
||||||
|
StringPiece u(uri); \
|
||||||
|
StringPiece s, h, p; \
|
||||||
|
ParseURI(u, &s, &h, &p); \
|
||||||
|
EXPECT_EQ(scheme, s.ToString()); \
|
||||||
|
EXPECT_EQ(host, h.ToString()); \
|
||||||
|
EXPECT_EQ(path, p.ToString()); \
|
||||||
|
EXPECT_EQ(uri, CreateURI(scheme, host, path)); \
|
||||||
|
EXPECT_LE(u.begin(), s.begin()); \
|
||||||
|
EXPECT_GE(u.end(), s.begin()); \
|
||||||
|
EXPECT_LE(u.begin(), s.end()); \
|
||||||
|
EXPECT_GE(u.end(), s.end()); \
|
||||||
|
EXPECT_LE(u.begin(), h.begin()); \
|
||||||
|
EXPECT_GE(u.end(), h.begin()); \
|
||||||
|
EXPECT_LE(u.begin(), h.end()); \
|
||||||
|
EXPECT_GE(u.end(), h.end()); \
|
||||||
|
EXPECT_LE(u.begin(), p.begin()); \
|
||||||
|
EXPECT_GE(u.end(), p.begin()); \
|
||||||
|
EXPECT_LE(u.begin(), p.end()); \
|
||||||
|
EXPECT_GE(u.end(), p.end()); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
TEST(PathTest, CreateParseURI) {
|
||||||
|
EXPECT_PARSE_URI("http://foo", "http", "foo", "");
|
||||||
|
EXPECT_PARSE_URI("/encrypted/://foo", "", "", "/encrypted/://foo");
|
||||||
|
EXPECT_PARSE_URI("/usr/local/foo", "", "", "/usr/local/foo");
|
||||||
|
EXPECT_PARSE_URI("file:///usr/local/foo", "file", "", "/usr/local/foo");
|
||||||
|
EXPECT_PARSE_URI("local.file:///usr/local/foo", "local.file", "",
|
||||||
|
"/usr/local/foo");
|
||||||
|
EXPECT_PARSE_URI("a-b:///foo", "", "", "a-b:///foo");
|
||||||
|
EXPECT_PARSE_URI(":///foo", "", "", ":///foo");
|
||||||
|
EXPECT_PARSE_URI("9dfd:///foo", "", "", "9dfd:///foo");
|
||||||
|
EXPECT_PARSE_URI("file:", "", "", "file:");
|
||||||
|
EXPECT_PARSE_URI("file:/", "", "", "file:/");
|
||||||
|
EXPECT_PARSE_URI("hdfs://localhost:8020/path/to/file", "hdfs",
|
||||||
|
"localhost:8020", "/path/to/file");
|
||||||
|
EXPECT_PARSE_URI("hdfs://localhost:8020", "hdfs", "localhost:8020", "");
|
||||||
|
EXPECT_PARSE_URI("hdfs://localhost:8020/", "hdfs", "localhost:8020", "/");
|
||||||
|
}
|
||||||
|
#undef EXPECT_PARSE_URI
|
||||||
|
|
||||||
} // namespace io
|
} // namespace io
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -4387,6 +4387,83 @@ output_min: This value is copied from input_min.
|
|||||||
output_max: This value is copied from input_max.
|
output_max: This value is copied from input_max.
|
||||||
)Doc");
|
)Doc");
|
||||||
|
|
||||||
|
REGISTER_OP("ScatterNd")
|
||||||
|
.Input("indices: Tindices")
|
||||||
|
.Input("updates: T")
|
||||||
|
.Input("shape: Tindices")
|
||||||
|
.Output("output: T")
|
||||||
|
.Attr("T: type")
|
||||||
|
.Attr("Tindices: {int32, int64}")
|
||||||
|
.Doc(
|
||||||
|
R"doc(Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices.
|
||||||
|
This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor.
|
||||||
|
|
||||||
|
TODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax.
|
||||||
|
|
||||||
|
`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||||
|
|
||||||
|
`indices` must be an integer tensor containing indices into `shape`.
|
||||||
|
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||||
|
|
||||||
|
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||||
|
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||||
|
dimension of `shape`.
|
||||||
|
|
||||||
|
`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
|
||||||
|
|
||||||
|
```
|
||||||
|
[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]].
|
||||||
|
```
|
||||||
|
|
||||||
|
The simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.
|
||||||
|
|
||||||
|
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||||
|
<img style="width:100%" src="../../images/ScatterNd1.png" alt>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
In Python, this scatter operation would look like this:
|
||||||
|
|
||||||
|
indices = tf.constant([[4], [3], [1], [7]])
|
||||||
|
updates = tf.constant([9, 10, 11, 12])
|
||||||
|
shape = tf.constant([8])
|
||||||
|
scatter = tf.scatter_nd(indices, updates, shape)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
print sess.run(scatter)
|
||||||
|
|
||||||
|
The resulting tensor would look like this:
|
||||||
|
|
||||||
|
[0, 11, 0, 10, 9, 0, 0, 12]
|
||||||
|
|
||||||
|
We can also insert entire slices of a higher rank tensor all at once. For example, say we want to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.
|
||||||
|
|
||||||
|
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||||
|
<img style="width:100%" src="../../images/ScatterNd2.png" alt>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
In Python, this scatter operation would look like this:
|
||||||
|
|
||||||
|
indices = tf.constant([[0], [2]])
|
||||||
|
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||||
|
[7, 7, 7, 7], [8, 8, 8, 8]],
|
||||||
|
[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||||
|
[7, 7, 7, 7], [8, 8, 8, 8]]])
|
||||||
|
shape = tf.constant([4, 4, 4])
|
||||||
|
scatter = tf.scatter_nd(indices, updates, shape)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
print sess.run(scatter)
|
||||||
|
|
||||||
|
The resulting tensor would look like this:
|
||||||
|
|
||||||
|
[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||||
|
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||||
|
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||||
|
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
|
||||||
|
|
||||||
|
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into the output tensor.
|
||||||
|
updates: A Tensor. Must have the same type as output. A tensor of updated values to scatter into the output tensor.
|
||||||
|
shape: A vector. The shape of the resulting tensor.
|
||||||
|
output: A new tensor with the given shape and updates applied according to the indices.)doc");
|
||||||
|
|
||||||
REGISTER_OP("FakeQuantWithMinMaxArgs")
|
REGISTER_OP("FakeQuantWithMinMaxArgs")
|
||||||
.Attr("min: float = -6.0")
|
.Attr("min: float = -6.0")
|
||||||
.Attr("max: float = 6.0")
|
.Attr("max: float = 6.0")
|
||||||
@ -4409,6 +4486,7 @@ REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
|
|||||||
.Input("gradients: float")
|
.Input("gradients: float")
|
||||||
.Input("inputs: float")
|
.Input("inputs: float")
|
||||||
.Output("backprops: float")
|
.Output("backprops: float")
|
||||||
|
.SetShapeFn(shape_inference::UnchangedShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Compute gradients for a FakeQuantWithMinMaxArgs operation.
|
Compute gradients for a FakeQuantWithMinMaxArgs operation.
|
||||||
|
|
||||||
@ -4450,6 +4528,21 @@ REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
|
|||||||
.Output("backprops_wrt_input: float")
|
.Output("backprops_wrt_input: float")
|
||||||
.Output("backprop_wrt_min: float")
|
.Output("backprop_wrt_min: float")
|
||||||
.Output("backprop_wrt_max: float")
|
.Output("backprop_wrt_max: float")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
// gradients and inputs are same size.
|
||||||
|
ShapeHandle inputs;
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
|
||||||
|
|
||||||
|
// min and max are scalars
|
||||||
|
ShapeHandle min_max;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));
|
||||||
|
|
||||||
|
c->set_output(0, inputs);
|
||||||
|
c->set_output(1, min_max);
|
||||||
|
c->set_output(2, min_max);
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Compute gradients for a FakeQuantWithMinMaxVars operation.
|
Compute gradients for a FakeQuantWithMinMaxVars operation.
|
||||||
|
|
||||||
@ -4503,6 +4596,24 @@ REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
|
|||||||
.Output("backprops_wrt_input: float")
|
.Output("backprops_wrt_input: float")
|
||||||
.Output("backprop_wrt_min: float")
|
.Output("backprop_wrt_min: float")
|
||||||
.Output("backprop_wrt_max: float")
|
.Output("backprop_wrt_max: float")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeHandle inputs;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));
|
||||||
|
|
||||||
|
ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
|
||||||
|
|
||||||
|
ShapeHandle min_max;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));
|
||||||
|
|
||||||
|
c->set_output(0, inputs);
|
||||||
|
c->set_output(1, min_max);
|
||||||
|
c->set_output(2, min_max);
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.
|
Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.
|
||||||
|
|
||||||
|
@ -1533,4 +1533,23 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
|
|||||||
INFER_ERROR("must be equal", op, "[5];[4];[?]");
|
INFER_ERROR("must be equal", op, "[5];[4];[?]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
|
||||||
|
ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");
|
||||||
|
|
||||||
|
INFER_OK(op, "?;?;?;?", "?;[?];[?]");
|
||||||
|
INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
|
||||||
|
INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
|
||||||
|
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");
|
||||||
|
|
||||||
|
// Rank check vectors.
|
||||||
|
INFER_ERROR("be equal rank", op, "[1,?,3];[1,?,3];[3];[]");
|
||||||
|
INFER_ERROR("be rank 1", op, "[1,?,3];[1,?,3];[];[3]");
|
||||||
|
INFER_ERROR("be at least rank 1", op, "[];[];[1];[1]");
|
||||||
|
INFER_ERROR("be at most rank 4", op, "[1,2,3,4,5];[1,2,3,4,5];[1];[1]");
|
||||||
|
|
||||||
|
// Vectors must match each other, and match last dim of input.
|
||||||
|
INFER_ERROR("must be equal", op, "[1,3];[1,3];[2];[3]");
|
||||||
|
INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -24732,6 +24732,321 @@ op {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNd"
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "shape"
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdAdd"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdDiv"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdMul"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdSub"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdUpdate"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "ScatterSub"
|
name: "ScatterSub"
|
||||||
input_arg {
|
input_arg {
|
||||||
|
@ -15427,6 +15427,362 @@ op {
|
|||||||
summary: "Multiplies sparse updates into a variable reference."
|
summary: "Multiplies sparse updates into a variable reference."
|
||||||
description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] *= updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] *= updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nDuplicate entries are handled correctly: if multiple `indices` reference\nthe same location, their contributions multiply.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`."
|
description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] *= updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] *= updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nDuplicate entries are handled correctly: if multiple `indices` reference\nthe same location, their contributions multiply.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`."
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNd"
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into the output tensor."
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
description: "A Tensor. Must have the same type as output. A tensor of updated values to scatter into the output tensor."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "shape"
|
||||||
|
description: "A vector. The shape of the resulting tensor."
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
description: "A new tensor with the given shape and updates applied according to the indices."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
summary: "Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices."
|
||||||
|
description: "This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor.\n\nTODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax.\n\n`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `shape`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `shape`.\n\n`updates` is Tensor of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]].\n```\n\nThe simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterNd1.png\" alt>\n</div>\n\nIn Python, this scatter operation would look like this:\n\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n shape = tf.constant([8])\n scatter = tf.scatter_nd(indices, updates, shape)\n with tf.Session() as sess:\n print sess.run(scatter)\n\nThe resulting tensor would look like this:\n\n [0, 11, 0, 10, 9, 0, 0, 12]\n\nWe can also, insert entire slices of a higher rank tensor all at once. 
For example, if we wanted to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterNd2.png\" alt>\n</div>\n\nIn Python, this scatter operation would look like this:\n\n indices = tf.constant([[0], [2]])\n updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],\n [7, 7, 7, 7], [8, 8, 8, 8]],\n [[5, 5, 5, 5], [6, 6, 6, 6],\n [7, 7, 7, 7], [8, 8, 8, 8]]])\n shape = tf.constant([4, 4, 4])\n scatter = tf.scatter_nd(indices, updates, shape)\n with tf.Session() as sess:\n print sess.run(scatter)\n\nThe resulting tensor would look like this:\n\n [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],\n [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]"
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdAdd"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
description: "A mutable Tensor. Should be from a Variable node."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
description: "A Tensor. Must have the same type as ref. A tensor of updated values to add to ref."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
description: "An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||||
|
}
|
||||||
|
summary: "Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`."
|
||||||
|
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to add 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that addition would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n add = tf.scatter_nd_add(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(add)\n\nThe resulting update to ref would look like this:\n\n [1, 13, 3, 14, 14, 6, 7, 20]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdDiv"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
description: "A mutable Tensor. Should be from a Variable node."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
description: "A Tensor. Must have the same type as ref. A tensor of values that ref is divided by."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
description: "An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||||
|
}
|
||||||
|
summary: "Applies sparse division between `updates` and individual values or slices within a given variable according to `indices`."
|
||||||
|
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:\n\n ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([2, 3, 4, 5])\n sub = tf.scatter_nd_div(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [10, 5, 30, 13, 25, 60, 70, 16]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdMul"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
description: "A mutable Tensor. Should be from a Variable node."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
description: "A Tensor. Must have the same type as ref. A tensor of values that ref is multiplied by."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
description: "An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||||
|
}
|
||||||
|
summary: "Applies sparse multiplication between `updates` and individual values or slices within a given variable according to `indices`."
|
||||||
|
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n sub = tf.scatter_nd_mul(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [1, 22, 3, 40, 45, 6, 7, 96]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdSub"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
description: "A mutable Tensor. Should be from a Variable node."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
description: "An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||||
|
}
|
||||||
|
summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`."
|
||||||
|
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n sub = tf.scatter_nd_sub(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [1, -9, 3, -6, -4, 6, 7, -4]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "ScatterNdUpdate"
|
||||||
|
input_arg {
|
||||||
|
name: "ref"
|
||||||
|
description: "A mutable Tensor. Should be from a Variable node."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "indices"
|
||||||
|
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||||
|
type_attr: "Tindices"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "updates"
|
||||||
|
description: "A Tensor. Must have the same type as ref. A tensor of updated values to store in ref."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output_ref"
|
||||||
|
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||||
|
type_attr: "T"
|
||||||
|
is_ref: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tindices"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "use_locking"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: true
|
||||||
|
}
|
||||||
|
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||||
|
}
|
||||||
|
summary: "Applies sparse `updates` to individual values or slices within a given variable according to `indices`."
|
||||||
|
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to update 4 scattered elements in a rank-1 tensor with 8 elements. In Python, that update would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n update = tf.scatter_nd_update(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(update)\n\nThe resulting update to ref would look like this:\n\n [1, 11, 3, 10, 9, 6, 7, 12]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "ScatterSub"
|
name: "ScatterSub"
|
||||||
input_arg {
|
input_arg {
|
||||||
|
@ -445,6 +445,241 @@ use_locking: If True, the operation will be protected by a lock;
|
|||||||
otherwise the behavior is undefined, but may exhibit less contention.
|
otherwise the behavior is undefined, but may exhibit less contention.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("ScatterNdUpdate")
|
||||||
|
.Input("ref: Ref(T)")
|
||||||
|
.Input("indices: Tindices")
|
||||||
|
.Input("updates: T")
|
||||||
|
.Output("output_ref: Ref(T)")
|
||||||
|
.Attr("T: type")
|
||||||
|
.Attr("Tindices: {int32, int64}")
|
||||||
|
.Attr("use_locking: bool = true")
|
||||||
|
.Doc(
|
||||||
|
R"doc(Applies sparse `updates` to individual values or slices within a given variable according to `indices`.
|
||||||
|
|
||||||
|
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||||
|
|
||||||
|
`indices` must be integer tensor, containing indices into `ref`.
|
||||||
|
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||||
|
|
||||||
|
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||||
|
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||||
|
dimension of `ref`.
|
||||||
|
|
||||||
|
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||||
|
|
||||||
|
```
|
||||||
|
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, say we want to update 4 scattered elements in a rank-1 tensor with 8 elements. In Python, that update would look like this:
|
||||||
|
|
||||||
|
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||||
|
indices = tf.constant([[4], [3], [1], [7]])
|
||||||
|
updates = tf.constant([9, 10, 11, 12])
|
||||||
|
update = tf.scatter_nd_update(ref, indices, updates)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
print sess.run(update)
|
||||||
|
|
||||||
|
The resulting update to ref would look like this:
|
||||||
|
|
||||||
|
[1, 11, 3, 10, 9, 6, 7, 12]
|
||||||
|
|
||||||
|
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||||
|
|
||||||
|
ref: A mutable Tensor. Should be from a Variable node.
|
||||||
|
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||||
|
updates: A Tensor. Must have the same type as ref. A tensor of updated values to store in ref.
|
||||||
|
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||||
|
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("ScatterNdAdd")
|
||||||
|
.Input("ref: Ref(T)")
|
||||||
|
.Input("indices: Tindices")
|
||||||
|
.Input("updates: T")
|
||||||
|
.Output("output_ref: Ref(T)")
|
||||||
|
.Attr("T: numbertype")
|
||||||
|
.Attr("Tindices: {int32, int64}")
|
||||||
|
.Attr("use_locking: bool = false")
|
||||||
|
.Doc(
|
||||||
|
R"doc(Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`.
|
||||||
|
|
||||||
|
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||||
|
|
||||||
|
`indices` must be integer tensor, containing indices into `ref`.
|
||||||
|
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||||
|
|
||||||
|
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||||
|
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||||
|
dimension of `ref`.
|
||||||
|
|
||||||
|
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||||
|
|
||||||
|
```
|
||||||
|
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, say we want to add 4 scattered elements to a rank-1 tensor with 8 elements. In Python, that addition would look like this:
|
||||||
|
|
||||||
|
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||||
|
indices = tf.constant([[4], [3], [1], [7]])
|
||||||
|
updates = tf.constant([9, 10, 11, 12])
|
||||||
|
add = tf.scatter_nd_add(ref, indices, updates)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
print sess.run(add)
|
||||||
|
|
||||||
|
The resulting update to ref would look like this:
|
||||||
|
|
||||||
|
[1, 13, 3, 14, 14, 6, 7, 20]
|
||||||
|
|
||||||
|
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||||
|
|
||||||
|
ref: A mutable Tensor. Should be from a Variable node.
|
||||||
|
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||||
|
updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref.
|
||||||
|
use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||||
|
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("ScatterNdSub")
|
||||||
|
.Input("ref: Ref(T)")
|
||||||
|
.Input("indices: Tindices")
|
||||||
|
.Input("updates: T")
|
||||||
|
.Output("output_ref: Ref(T)")
|
||||||
|
.Attr("T: numbertype")
|
||||||
|
.Attr("Tindices: {int32, int64}")
|
||||||
|
.Attr("use_locking: bool = false")
|
||||||
|
.Doc(
|
||||||
|
R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
|
||||||
|
|
||||||
|
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||||
|
|
||||||
|
`indices` must be integer tensor, containing indices into `ref`.
|
||||||
|
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||||
|
|
||||||
|
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||||
|
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||||
|
dimension of `ref`.
|
||||||
|
|
||||||
|
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||||
|
|
||||||
|
```
|
||||||
|
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this:
|
||||||
|
|
||||||
|
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||||
|
indices = tf.constant([[4], [3], [1], [7]])
|
||||||
|
updates = tf.constant([9, 10, 11, 12])
|
||||||
|
sub = tf.scatter_nd_sub(ref, indices, updates)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
print sess.run(sub)
|
||||||
|
|
||||||
|
The resulting update to ref would look like this:
|
||||||
|
|
||||||
|
[1, -9, 3, -6, -4, 6, 7, -4]
|
||||||
|
|
||||||
|
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||||
|
|
||||||
|
ref: A mutable Tensor. Should be from a Variable node.
|
||||||
|
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||||
|
updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
|
||||||
|
use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||||
|
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("ScatterNdMul")
|
||||||
|
.Input("ref: Ref(T)")
|
||||||
|
.Input("indices: Tindices")
|
||||||
|
.Input("updates: T")
|
||||||
|
.Output("output_ref: Ref(T)")
|
||||||
|
.Attr("T: numbertype")
|
||||||
|
.Attr("Tindices: {int32, int64}")
|
||||||
|
.Attr("use_locking: bool = false")
|
||||||
|
.Doc(
|
||||||
|
R"doc(Applies sparse multiplication between `updates` and individual values or slices within a given variable according to `indices`.
|
||||||
|
|
||||||
|
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||||
|
|
||||||
|
`indices` must be integer tensor, containing indices into `ref`.
|
||||||
|
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||||
|
|
||||||
|
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||||
|
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||||
|
dimension of `ref`.
|
||||||
|
|
||||||
|
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||||
|
|
||||||
|
```
|
||||||
|
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:
|
||||||
|
|
||||||
|
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||||
|
indices = tf.constant([[4], [3], [1], [7]])
|
||||||
|
updates = tf.constant([9, 10, 11, 12])
|
||||||
|
sub = tf.scatter_nd_mul(ref, indices, updates)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
print sess.run(sub)
|
||||||
|
|
||||||
|
The resulting update to ref would look like this:
|
||||||
|
|
||||||
|
[1, 22, 3, 40, 45, 6, 7, 96]
|
||||||
|
|
||||||
|
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||||
|
|
||||||
|
ref: A mutable Tensor. Should be from a Variable node.
|
||||||
|
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||||
|
updates: A Tensor. Must have the same type as ref. A tensor of updated values to multiply ref by.
|
||||||
|
use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||||
|
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("ScatterNdDiv")
|
||||||
|
.Input("ref: Ref(T)")
|
||||||
|
.Input("indices: Tindices")
|
||||||
|
.Input("updates: T")
|
||||||
|
.Output("output_ref: Ref(T)")
|
||||||
|
.Attr("T: numbertype")
|
||||||
|
.Attr("Tindices: {int32, int64}")
|
||||||
|
.Attr("use_locking: bool = false")
|
||||||
|
.Doc(
|
||||||
|
R"doc(Applies sparse division between `updates` and individual values or slices within a given variable according to `indices`.
|
||||||
|
|
||||||
|
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||||
|
|
||||||
|
`indices` must be integer tensor, containing indices into `ref`.
|
||||||
|
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||||
|
|
||||||
|
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||||
|
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||||
|
dimension of `ref`.
|
||||||
|
|
||||||
|
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||||
|
|
||||||
|
```
|
||||||
|
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:
|
||||||
|
|
||||||
|
ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])
|
||||||
|
indices = tf.constant([[4], [3], [1], [7]])
|
||||||
|
updates = tf.constant([2, 3, 4, 5])
|
||||||
|
sub = tf.scatter_nd_div(ref, indices, updates)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
print sess.run(sub)
|
||||||
|
|
||||||
|
The resulting update to ref would look like this:
|
||||||
|
|
||||||
|
[10, 5, 30, 13, 25, 60, 70, 16]
|
||||||
|
|
||||||
|
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||||
|
|
||||||
|
ref: A mutable Tensor. Should be from a Variable node.
|
||||||
|
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||||
|
updates: A Tensor. Must have the same type as ref. A tensor of updated values to divide ref by.
|
||||||
|
use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||||
|
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||||
|
|
||||||
REGISTER_OP("CountUpTo")
|
REGISTER_OP("CountUpTo")
|
||||||
.Input("ref: Ref(T)")
|
.Input("ref: Ref(T)")
|
||||||
.Output("output: T")
|
.Output("output: T")
|
||||||
|
@ -81,7 +81,7 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
|
|||||||
return errors::Internal("bucket and object cannot be null.");
|
return errors::Internal("bucket and object cannot be null.");
|
||||||
}
|
}
|
||||||
StringPiece scheme, bucketp, objectp;
|
StringPiece scheme, bucketp, objectp;
|
||||||
ParseURI(fname, &scheme, &bucketp, &objectp);
|
io::ParseURI(fname, &scheme, &bucketp, &objectp);
|
||||||
if (scheme != "gs") {
|
if (scheme != "gs") {
|
||||||
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
|
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
|
||||||
fname);
|
fname);
|
||||||
|
@ -76,16 +76,24 @@ cc_library(
|
|||||||
name = "platformlib",
|
name = "platformlib",
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
|
":gif",
|
||||||
|
":jpeg",
|
||||||
"//tensorflow/core:protos_cc",
|
"//tensorflow/core:protos_cc",
|
||||||
"@com_googlesource_code_re2//:re2",
|
"@com_googlesource_code_re2//:re2",
|
||||||
"@farmhash_archive//:farmhash",
|
"@farmhash_archive//:farmhash",
|
||||||
"@gif_archive//:gif",
|
|
||||||
"@highwayhash//:sip_hash",
|
"@highwayhash//:sip_hash",
|
||||||
"@jpeg_archive//:jpeg",
|
|
||||||
"@png_archive//:png",
|
"@png_archive//:png",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gif",
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
"@gif_archive//:gif",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "jpeg",
|
name = "jpeg",
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
|
@ -70,7 +70,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {}
|
|||||||
|
|
||||||
Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) {
|
Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) {
|
||||||
StringPiece scheme, host, path;
|
StringPiece scheme, host, path;
|
||||||
ParseURI(fname, &scheme, &host, &path);
|
io::ParseURI(fname, &scheme, &host, &path);
|
||||||
FileSystem* file_system = file_system_registry_->Lookup(scheme.ToString());
|
FileSystem* file_system = file_system_registry_->Lookup(scheme.ToString());
|
||||||
if (!file_system) {
|
if (!file_system) {
|
||||||
return errors::Unimplemented("File system scheme ", scheme,
|
return errors::Unimplemented("File system scheme ", scheme,
|
||||||
|
@ -229,35 +229,6 @@ TEST_F(DefaultEnvTest, LocalFileSystem) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define EXPECT_PARSE_URI(uri, scheme, host, path) \
|
|
||||||
do { \
|
|
||||||
StringPiece s, h, p; \
|
|
||||||
ParseURI(uri, &s, &h, &p); \
|
|
||||||
EXPECT_EQ(scheme, s.ToString()); \
|
|
||||||
EXPECT_EQ(host, h.ToString()); \
|
|
||||||
EXPECT_EQ(path, p.ToString()); \
|
|
||||||
EXPECT_EQ(uri, CreateURI(scheme, host, path)); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
TEST_F(DefaultEnvTest, CreateParseURI) {
|
|
||||||
EXPECT_PARSE_URI("http://foo", "http", "foo", "");
|
|
||||||
EXPECT_PARSE_URI("/encrypted/://foo", "", "", "/encrypted/://foo");
|
|
||||||
EXPECT_PARSE_URI("/usr/local/foo", "", "", "/usr/local/foo");
|
|
||||||
EXPECT_PARSE_URI("file:///usr/local/foo", "file", "", "/usr/local/foo");
|
|
||||||
EXPECT_PARSE_URI("local.file:///usr/local/foo", "local.file", "",
|
|
||||||
"/usr/local/foo");
|
|
||||||
EXPECT_PARSE_URI("a-b:///foo", "", "", "a-b:///foo");
|
|
||||||
EXPECT_PARSE_URI(":///foo", "", "", ":///foo");
|
|
||||||
EXPECT_PARSE_URI("9dfd:///foo", "", "", "9dfd:///foo");
|
|
||||||
EXPECT_PARSE_URI("file:", "", "", "file:");
|
|
||||||
EXPECT_PARSE_URI("file:/", "", "", "file:/");
|
|
||||||
EXPECT_PARSE_URI("hdfs://localhost:8020/path/to/file", "hdfs",
|
|
||||||
"localhost:8020", "/path/to/file");
|
|
||||||
EXPECT_PARSE_URI("hdfs://localhost:8020", "hdfs", "localhost:8020", "");
|
|
||||||
EXPECT_PARSE_URI("hdfs://localhost:8020/", "hdfs", "localhost:8020", "/");
|
|
||||||
}
|
|
||||||
#undef EXPECT_PARSE_URI
|
|
||||||
|
|
||||||
TEST_F(DefaultEnvTest, SleepForMicroseconds) {
|
TEST_F(DefaultEnvTest, SleepForMicroseconds) {
|
||||||
const int64 start = env_->NowMicros();
|
const int64 start = env_->NowMicros();
|
||||||
const int64 sleep_time = 1e6 + 5e5;
|
const int64 sleep_time = 1e6 + 5e5;
|
||||||
@ -274,14 +245,14 @@ class TmpDirFileSystem : public NullFileSystem {
|
|||||||
public:
|
public:
|
||||||
bool FileExists(const string& dir) override {
|
bool FileExists(const string& dir) override {
|
||||||
StringPiece scheme, host, path;
|
StringPiece scheme, host, path;
|
||||||
ParseURI(dir, &scheme, &host, &path);
|
io::ParseURI(dir, &scheme, &host, &path);
|
||||||
if (path.empty()) return false;
|
if (path.empty()) return false;
|
||||||
return Env::Default()->FileExists(io::JoinPath(BaseDir(), path));
|
return Env::Default()->FileExists(io::JoinPath(BaseDir(), path));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CreateDir(const string& dir) override {
|
Status CreateDir(const string& dir) override {
|
||||||
StringPiece scheme, host, path;
|
StringPiece scheme, host, path;
|
||||||
ParseURI(dir, &scheme, &host, &path);
|
io::ParseURI(dir, &scheme, &host, &path);
|
||||||
if (scheme != "tmpdirfs") {
|
if (scheme != "tmpdirfs") {
|
||||||
return errors::FailedPrecondition("scheme must be tmpdirfs");
|
return errors::FailedPrecondition("scheme must be tmpdirfs");
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/strings/scanner.h"
|
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
@ -79,43 +78,6 @@ WritableFile::~WritableFile() {}
|
|||||||
|
|
||||||
FileSystemRegistry::~FileSystemRegistry() {}
|
FileSystemRegistry::~FileSystemRegistry() {}
|
||||||
|
|
||||||
void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
|
|
||||||
StringPiece* path) {
|
|
||||||
// 0. Parse scheme
|
|
||||||
// Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]*
|
|
||||||
// TODO(keveman): Allow "+" and "-" in the scheme.
|
|
||||||
if (!strings::Scanner(remaining)
|
|
||||||
.One(strings::Scanner::LETTER)
|
|
||||||
.Many(strings::Scanner::LETTER_DIGIT_DOT)
|
|
||||||
.StopCapture()
|
|
||||||
.OneLiteral("://")
|
|
||||||
.GetResult(&remaining, scheme)) {
|
|
||||||
// If there's no scheme, assume the entire string is a path.
|
|
||||||
scheme->clear();
|
|
||||||
host->clear();
|
|
||||||
*path = remaining;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. Parse host
|
|
||||||
if (!strings::Scanner(remaining).ScanUntil('/').GetResult(&remaining, host)) {
|
|
||||||
// No path, so the rest of the URI is the host.
|
|
||||||
*host = remaining;
|
|
||||||
path->clear();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. The rest is the path
|
|
||||||
*path = remaining;
|
|
||||||
}
|
|
||||||
|
|
||||||
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) {
|
|
||||||
if (scheme.empty()) {
|
|
||||||
return path.ToString();
|
|
||||||
}
|
|
||||||
return strings::StrCat(scheme, "://", host, path);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status FileSystem::GetMatchingPaths(const string& pattern,
|
Status FileSystem::GetMatchingPaths(const string& pattern,
|
||||||
std::vector<string>* results) {
|
std::vector<string>* results) {
|
||||||
results->clear();
|
results->clear();
|
||||||
@ -237,9 +199,9 @@ Status FileSystem::DeleteRecursively(const string& dirname,
|
|||||||
|
|
||||||
Status FileSystem::RecursivelyCreateDir(const string& dirname) {
|
Status FileSystem::RecursivelyCreateDir(const string& dirname) {
|
||||||
StringPiece scheme, host, remaining_dir;
|
StringPiece scheme, host, remaining_dir;
|
||||||
ParseURI(dirname, &scheme, &host, &remaining_dir);
|
io::ParseURI(dirname, &scheme, &host, &remaining_dir);
|
||||||
std::vector<StringPiece> sub_dirs;
|
std::vector<StringPiece> sub_dirs;
|
||||||
while (!FileExists(CreateURI(scheme, host, remaining_dir)) &&
|
while (!FileExists(io::CreateURI(scheme, host, remaining_dir)) &&
|
||||||
!remaining_dir.empty()) {
|
!remaining_dir.empty()) {
|
||||||
// Basename returns "" for / ending dirs.
|
// Basename returns "" for / ending dirs.
|
||||||
if (!remaining_dir.ends_with("/")) {
|
if (!remaining_dir.ends_with("/")) {
|
||||||
@ -255,7 +217,7 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) {
|
|||||||
string built_path = remaining_dir.ToString();
|
string built_path = remaining_dir.ToString();
|
||||||
for (const StringPiece sub_dir : sub_dirs) {
|
for (const StringPiece sub_dir : sub_dirs) {
|
||||||
built_path = io::JoinPath(built_path, sub_dir);
|
built_path = io::JoinPath(built_path, sub_dir);
|
||||||
TF_RETURN_IF_ERROR(CreateDir(CreateURI(scheme, host, built_path)));
|
TF_RETURN_IF_ERROR(CreateDir(io::CreateURI(scheme, host, built_path)));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -287,19 +287,6 @@ class FileSystemRegistry {
|
|||||||
std::vector<string>* schemes) = 0;
|
std::vector<string>* schemes) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Populates the scheme, host, and path from a URI.
|
|
||||||
//
|
|
||||||
// Corner cases:
|
|
||||||
// - If the URI is invalid, scheme and host are set to empty strings and the
|
|
||||||
// passed string is assumed to be a path
|
|
||||||
// - If the URI omits the path (e.g. file://host), then the path is left empty.
|
|
||||||
void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host,
|
|
||||||
StringPiece* path);
|
|
||||||
|
|
||||||
// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
|
|
||||||
// return the path.
|
|
||||||
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path);
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_H_
|
#endif // TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_H_
|
||||||
|
@ -112,7 +112,7 @@ class InterPlanetaryFileSystem : public NullFileSystem {
|
|||||||
|
|
||||||
void ParsePath(const string& name, string* parsed_path) {
|
void ParsePath(const string& name, string* parsed_path) {
|
||||||
StringPiece scheme, host, path;
|
StringPiece scheme, host, path;
|
||||||
ParseURI(name, &scheme, &host, &path);
|
io::ParseURI(name, &scheme, &host, &path);
|
||||||
ASSERT_EQ(scheme, "ipfs");
|
ASSERT_EQ(scheme, "ipfs");
|
||||||
ASSERT_EQ(host, "solarsystem");
|
ASSERT_EQ(host, "solarsystem");
|
||||||
path.Consume("/");
|
path.Consume("/");
|
||||||
|
@ -126,7 +126,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
|
|||||||
TF_RETURN_IF_ERROR(hdfs_->status());
|
TF_RETURN_IF_ERROR(hdfs_->status());
|
||||||
|
|
||||||
StringPiece scheme, namenode, path;
|
StringPiece scheme, namenode, path;
|
||||||
ParseURI(fname, &scheme, &namenode, &path);
|
io::ParseURI(fname, &scheme, &namenode, &path);
|
||||||
const string nn = namenode.ToString();
|
const string nn = namenode.ToString();
|
||||||
|
|
||||||
hdfsBuilder* builder = hdfs_->hdfsNewBuilder();
|
hdfsBuilder* builder = hdfs_->hdfsNewBuilder();
|
||||||
@ -144,7 +144,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
|
|||||||
|
|
||||||
string HadoopFileSystem::TranslateName(const string& name) const {
|
string HadoopFileSystem::TranslateName(const string& name) const {
|
||||||
StringPiece scheme, namenode, path;
|
StringPiece scheme, namenode, path;
|
||||||
ParseURI(name, &scheme, &namenode, &path);
|
io::ParseURI(name, &scheme, &namenode, &path);
|
||||||
return path.ToString();
|
return path.ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,7 +120,8 @@ class PosixEnv : public Env {
|
|||||||
symbol);
|
symbol);
|
||||||
}
|
}
|
||||||
|
|
||||||
string FormatLibraryFileName(const string& name, const string& version) {
|
string FormatLibraryFileName(const string& name,
|
||||||
|
const string& version) override {
|
||||||
return tensorflow::internal::FormatLibraryFileName(name, version);
|
return tensorflow::internal::FormatLibraryFileName(name, version);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
|
#ifndef TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
|
||||||
#define TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
|
#define TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -63,7 +64,7 @@ class LocalPosixFileSystem : public PosixFileSystem {
|
|||||||
public:
|
public:
|
||||||
string TranslateName(const string& name) const override {
|
string TranslateName(const string& name) const override {
|
||||||
StringPiece scheme, host, path;
|
StringPiece scheme, host, path;
|
||||||
ParseURI(name, &scheme, &host, &path);
|
io::ParseURI(name, &scheme, &host, &path);
|
||||||
return path.ToString();
|
return path.ToString();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
|
#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
|
||||||
#define TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
|
#define TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/file_system.h"
|
#include "tensorflow/core/platform/file_system.h"
|
||||||
|
|
||||||
#ifdef PLATFORM_WINDOWS
|
#ifdef PLATFORM_WINDOWS
|
||||||
@ -68,7 +69,7 @@ class LocalWinFileSystem : public WindowsFileSystem {
|
|||||||
public:
|
public:
|
||||||
string TranslateName(const string& name) const override {
|
string TranslateName(const string& name) const override {
|
||||||
StringPiece scheme, host, path;
|
StringPiece scheme, host, path;
|
||||||
ParseURI(name, &scheme, &host, &path);
|
io::ParseURI(name, &scheme, &host, &path);
|
||||||
return path.ToString();
|
return path.ToString();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -122,6 +122,10 @@ message RunStepRequest {
|
|||||||
|
|
||||||
// Options for the run call.
|
// Options for the run call.
|
||||||
RunOptions options = 5;
|
RunOptions options = 5;
|
||||||
|
|
||||||
|
// Partial run handle (optional). If specified, this will be a partial run
|
||||||
|
// execution, run up to the specified fetches.
|
||||||
|
string partial_run_handle = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message RunStepResponse {
|
message RunStepResponse {
|
||||||
@ -133,6 +137,42 @@ message RunStepResponse {
|
|||||||
RunMetadata metadata = 2;
|
RunMetadata metadata = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
//
|
||||||
|
// PartialRunSetup method request/response protos.
|
||||||
|
//
|
||||||
|
// The caller should provide the future partial run feeds, fetches, and targets.
|
||||||
|
// Then the caller can use RunStepRequest with is_partial set to make partial
|
||||||
|
// run calls.
|
||||||
|
//
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
message PartialRunSetupRequest {
|
||||||
|
// REQUIRED: session_handle must be returned by a CreateSession call
|
||||||
|
// to the same master service.
|
||||||
|
string session_handle = 1;
|
||||||
|
|
||||||
|
// Tensors to be fed in future steps.
|
||||||
|
repeated string feed = 2;
|
||||||
|
|
||||||
|
// Fetches. A list of tensor names. The caller expects a tensor to be returned
|
||||||
|
// for each fetch[i] (see RunStepResponse.tensor), for corresponding partial
|
||||||
|
// RunStepRequests. The order of specified fetches does not change the
|
||||||
|
// execution order.
|
||||||
|
repeated string fetch = 3;
|
||||||
|
|
||||||
|
// Target Nodes. A list of node names. The named nodes will be run in future
|
||||||
|
// steps, but their outputs will not be fetched.
|
||||||
|
repeated string target = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PartialRunSetupResponse {
|
||||||
|
// The unique handle corresponding to the ongoing partial run call setup by
|
||||||
|
// the invocation to PartialRunSetup. This handle may be passed to
|
||||||
|
// RunStepRequest to send and receive tensors for this partial run.
|
||||||
|
string partial_run_handle = 1;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
//
|
//
|
||||||
// CloseSession method request/response protos.
|
// CloseSession method request/response protos.
|
||||||
|
@ -91,6 +91,9 @@ service MasterService {
|
|||||||
// Extends a session.
|
// Extends a session.
|
||||||
rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse);
|
rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse);
|
||||||
|
|
||||||
|
// Prepares future partial run calls.
|
||||||
|
rpc PartialRunSetup(PartialRunSetupRequest) returns (PartialRunSetupResponse);
|
||||||
|
|
||||||
// Drives the graph computation.
|
// Drives the graph computation.
|
||||||
rpc RunStep(RunStepRequest) returns (RunStepResponse);
|
rpc RunStep(RunStepRequest) returns (RunStepResponse);
|
||||||
|
|
||||||
|
@ -343,7 +343,11 @@ Status BundleWriter::Finish() {
|
|||||||
status_ = env_->NewWritableFile(MetaFilename(prefix_), &file);
|
status_ = env_->NewWritableFile(MetaFilename(prefix_), &file);
|
||||||
if (!status_.ok()) return status_;
|
if (!status_.ok()) return status_;
|
||||||
{
|
{
|
||||||
table::TableBuilder builder(table::Options(), file.get());
|
// N.B.: the default use of Snappy compression may not be supported on all
|
||||||
|
// platforms (e.g. Android). The metadata file is small, so this is fine.
|
||||||
|
table::Options options;
|
||||||
|
options.compression = table::kNoCompression;
|
||||||
|
table::TableBuilder builder(options, file.get());
|
||||||
// Header entry.
|
// Header entry.
|
||||||
BundleHeaderProto header;
|
BundleHeaderProto header;
|
||||||
header.set_num_shards(1);
|
header.set_num_shards(1);
|
||||||
|
@ -31,12 +31,10 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
|
|||||||
work(0, total);
|
work(0, total);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
|
|
||||||
if (max_parallelism >= workers->NumThreads()) {
|
if (max_parallelism >= workers->NumThreads()) {
|
||||||
workers->ParallelFor(total, cost_per_unit, work);
|
workers->ParallelFor(total, cost_per_unit, work);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
cost_per_unit = std::max(1LL, cost_per_unit);
|
cost_per_unit = std::max(1LL, cost_per_unit);
|
||||||
// We shard [0, total) into "num_shards" shards.
|
// We shard [0, total) into "num_shards" shards.
|
||||||
// 1 <= num_shards <= num worker threads
|
// 1 <= num_shards <= num worker threads
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="utf-8"?><!--
|
|
||||||
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
-->
|
|
||||||
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
|
||||||
android:layout_width="match_parent"
|
|
||||||
android:layout_height="match_parent">
|
|
||||||
|
|
||||||
<org.tensorflow.demo.AutoFitTextureView
|
|
||||||
android:id="@+id/texture"
|
|
||||||
android:layout_width="wrap_content"
|
|
||||||
android:layout_height="wrap_content"
|
|
||||||
android:layout_alignParentBottom="true"
|
|
||||||
android:layout_alignParentStart="true"
|
|
||||||
android:layout_alignParentTop="true" />
|
|
||||||
|
|
||||||
<org.tensorflow.demo.RecognitionScoreView
|
|
||||||
android:id="@+id/results"
|
|
||||||
android:layout_width="match_parent"
|
|
||||||
android:layout_height="112dp"
|
|
||||||
android:layout_alignParentTop="true" />
|
|
||||||
|
|
||||||
</RelativeLayout>
|
|
@ -22,11 +22,17 @@
|
|||||||
android:layout_width="wrap_content"
|
android:layout_width="wrap_content"
|
||||||
android:layout_height="wrap_content"
|
android:layout_height="wrap_content"
|
||||||
android:layout_alignParentBottom="true" />
|
android:layout_alignParentBottom="true" />
|
||||||
|
|
||||||
<org.tensorflow.demo.RecognitionScoreView
|
<org.tensorflow.demo.RecognitionScoreView
|
||||||
android:id="@+id/results"
|
android:id="@+id/results"
|
||||||
android:layout_width="match_parent"
|
android:layout_width="match_parent"
|
||||||
android:layout_height="112dp"
|
android:layout_height="112dp"
|
||||||
android:layout_alignParentTop="true" />
|
android:layout_alignParentTop="true" />
|
||||||
|
|
||||||
|
<org.tensorflow.demo.OverlayView
|
||||||
|
android:id="@+id/overlay"
|
||||||
|
android:layout_width="match_parent"
|
||||||
|
android:layout_height="match_parent"
|
||||||
|
android:layout_alignParentBottom="true" />
|
||||||
|
|
||||||
</RelativeLayout>
|
</RelativeLayout>
|
||||||
|
@ -20,32 +20,72 @@ import android.Manifest;
|
|||||||
import android.app.Activity;
|
import android.app.Activity;
|
||||||
import android.app.Fragment;
|
import android.app.Fragment;
|
||||||
import android.content.pm.PackageManager;
|
import android.content.pm.PackageManager;
|
||||||
|
import android.media.Image.Plane;
|
||||||
|
import android.media.ImageReader.OnImageAvailableListener;
|
||||||
import android.os.Build;
|
import android.os.Build;
|
||||||
import android.os.Bundle;
|
import android.os.Bundle;
|
||||||
|
import android.os.Handler;
|
||||||
|
import android.os.HandlerThread;
|
||||||
|
import android.util.Size;
|
||||||
|
import android.view.MotionEvent;
|
||||||
import android.view.WindowManager;
|
import android.view.WindowManager;
|
||||||
import android.widget.Toast;
|
import android.widget.Toast;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import org.tensorflow.demo.env.Logger;
|
||||||
|
|
||||||
|
public abstract class CameraActivity extends Activity implements OnImageAvailableListener {
|
||||||
|
private static final Logger LOGGER = new Logger();
|
||||||
|
|
||||||
public abstract class CameraActivity extends Activity {
|
|
||||||
private static final int PERMISSIONS_REQUEST = 1;
|
private static final int PERMISSIONS_REQUEST = 1;
|
||||||
|
|
||||||
private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA;
|
private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA;
|
||||||
private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE;
|
private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE;
|
||||||
|
|
||||||
|
private boolean debug = false;
|
||||||
|
|
||||||
|
private Handler handler;
|
||||||
|
private HandlerThread handlerThread;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void onCreate(final Bundle savedInstanceState) {
|
protected void onCreate(final Bundle savedInstanceState) {
|
||||||
super.onCreate(savedInstanceState);
|
super.onCreate(null);
|
||||||
getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
|
getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
|
||||||
|
|
||||||
setContentView(R.layout.activity_camera);
|
setContentView(R.layout.activity_camera);
|
||||||
|
|
||||||
if (hasPermission()) {
|
if (hasPermission()) {
|
||||||
if (null == savedInstanceState) {
|
setFragment();
|
||||||
setFragment();
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
requestPermission();
|
requestPermission();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized void onResume() {
|
||||||
|
super.onResume();
|
||||||
|
|
||||||
|
handlerThread = new HandlerThread("inference");
|
||||||
|
handlerThread.start();
|
||||||
|
handler = new Handler(handlerThread.getLooper());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized void onPause() {
|
||||||
|
super.onPause();
|
||||||
|
handlerThread.quitSafely();
|
||||||
|
try {
|
||||||
|
handlerThread.join();
|
||||||
|
handlerThread = null;
|
||||||
|
handler = null;
|
||||||
|
} catch (final InterruptedException e) {
|
||||||
|
LOGGER.e(e, "Exception!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected synchronized void runInBackground(final Runnable r) {
|
||||||
|
if (handler != null) {
|
||||||
|
handler.post(r);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -82,11 +122,47 @@ public abstract class CameraActivity extends Activity {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected void setFragment() {
|
protected void setFragment() {
|
||||||
|
final Fragment fragment = CameraConnectionFragment.newInstance(
|
||||||
|
new CameraConnectionFragment.ConnectionCallback(){
|
||||||
|
@Override
|
||||||
|
public void onPreviewSizeChosen(final Size size, final int rotation) {
|
||||||
|
CameraActivity.this.onPreviewSizeChosen(size, rotation);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
this, getLayoutId(), getDesiredPreviewFrameSize());
|
||||||
|
|
||||||
getFragmentManager()
|
getFragmentManager()
|
||||||
.beginTransaction()
|
.beginTransaction()
|
||||||
.replace(R.id.container, createFragment())
|
.replace(R.id.container, fragment)
|
||||||
.commit();
|
.commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract Fragment createFragment();
|
protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) {
|
||||||
|
// Because of the variable row stride it's not possible to know in
|
||||||
|
// advance the actual necessary dimensions of the yuv planes.
|
||||||
|
for (int i = 0; i < planes.length; ++i) {
|
||||||
|
final ByteBuffer buffer = planes[i].getBuffer();
|
||||||
|
if (yuvBytes[i] == null) {
|
||||||
|
LOGGER.i("Initializing buffer %d at size %d", i, buffer.capacity());
|
||||||
|
yuvBytes[i] = new byte[buffer.capacity()];
|
||||||
|
}
|
||||||
|
buffer.get(yuvBytes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean onTouchEvent(final MotionEvent event) {
|
||||||
|
if (event.getAction() == MotionEvent.ACTION_DOWN) {
|
||||||
|
debug = !debug;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isDebug() {
|
||||||
|
return debug;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract void onPreviewSizeChosen(final Size size, final int rotation);
|
||||||
|
protected abstract int getLayoutId();
|
||||||
|
protected abstract int getDesiredPreviewFrameSize();
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,7 @@ import android.hardware.camera2.CaptureResult;
|
|||||||
import android.hardware.camera2.TotalCaptureResult;
|
import android.hardware.camera2.TotalCaptureResult;
|
||||||
import android.hardware.camera2.params.StreamConfigurationMap;
|
import android.hardware.camera2.params.StreamConfigurationMap;
|
||||||
import android.media.ImageReader;
|
import android.media.ImageReader;
|
||||||
|
import android.media.ImageReader.OnImageAvailableListener;
|
||||||
import android.os.Bundle;
|
import android.os.Bundle;
|
||||||
import android.os.Handler;
|
import android.os.Handler;
|
||||||
import android.os.HandlerThread;
|
import android.os.HandlerThread;
|
||||||
@ -49,9 +50,6 @@ import android.view.TextureView;
|
|||||||
import android.view.View;
|
import android.view.View;
|
||||||
import android.view.ViewGroup;
|
import android.view.ViewGroup;
|
||||||
import android.widget.Toast;
|
import android.widget.Toast;
|
||||||
|
|
||||||
import org.tensorflow.demo.env.Logger;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@ -59,6 +57,7 @@ import java.util.Comparator;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.Semaphore;
|
import java.util.concurrent.Semaphore;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import org.tensorflow.demo.env.Logger;
|
||||||
|
|
||||||
public class CameraConnectionFragment extends Fragment {
|
public class CameraConnectionFragment extends Fragment {
|
||||||
private static final Logger LOGGER = new Logger();
|
private static final Logger LOGGER = new Logger();
|
||||||
@ -69,8 +68,6 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
*/
|
*/
|
||||||
private static final int MINIMUM_PREVIEW_SIZE = 320;
|
private static final int MINIMUM_PREVIEW_SIZE = 320;
|
||||||
|
|
||||||
private ResultsView resultsView;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Conversion from screen rotation to JPEG orientation.
|
* Conversion from screen rotation to JPEG orientation.
|
||||||
*/
|
*/
|
||||||
@ -111,6 +108,14 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
|
public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Callback for Activities to use to initialize their data once the
|
||||||
|
* selected preview size is known.
|
||||||
|
*/
|
||||||
|
public interface ConnectionCallback {
|
||||||
|
void onPreviewSizeChosen(Size size, int cameraRotation);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ID of the current {@link CameraDevice}.
|
* ID of the current {@link CameraDevice}.
|
||||||
*/
|
*/
|
||||||
@ -184,16 +189,6 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
*/
|
*/
|
||||||
private Handler backgroundHandler;
|
private Handler backgroundHandler;
|
||||||
|
|
||||||
/**
|
|
||||||
* An additional thread for running inference so as not to block the camera.
|
|
||||||
*/
|
|
||||||
private HandlerThread inferenceThread;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A {@link Handler} for running tasks in the background.
|
|
||||||
*/
|
|
||||||
private Handler inferenceHandler;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An {@link ImageReader} that handles preview frame capture.
|
* An {@link ImageReader} that handles preview frame capture.
|
||||||
*/
|
*/
|
||||||
@ -215,9 +210,10 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
private final Semaphore cameraOpenCloseLock = new Semaphore(1);
|
private final Semaphore cameraOpenCloseLock = new Semaphore(1);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A {@link Classifier} object wrapping TensorFlow to pass frames to.
|
* A {@link OnImageAvailableListener} to receive frames as they are available.
|
||||||
*/
|
*/
|
||||||
private final Classifier classifier;
|
private final OnImageAvailableListener imageListener;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The input size in pixels desired by TensorFlow (width and height of a square bitmap).
|
* The input size in pixels desired by TensorFlow (width and height of a square bitmap).
|
||||||
*/
|
*/
|
||||||
@ -228,9 +224,15 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
*/
|
*/
|
||||||
private final int layout;
|
private final int layout;
|
||||||
|
|
||||||
|
|
||||||
|
private final ConnectionCallback cameraConnectionCallback;
|
||||||
|
|
||||||
private CameraConnectionFragment(
|
private CameraConnectionFragment(
|
||||||
final Classifier classifier, final int layout, final int inputSize) {
|
final ConnectionCallback connectionCallback,
|
||||||
this.classifier = classifier;
|
final OnImageAvailableListener imageListener,
|
||||||
|
final int layout, final int inputSize) {
|
||||||
|
this.cameraConnectionCallback = connectionCallback;
|
||||||
|
this.imageListener = imageListener;
|
||||||
this.layout = layout;
|
this.layout = layout;
|
||||||
this.inputSize = inputSize;
|
this.inputSize = inputSize;
|
||||||
}
|
}
|
||||||
@ -268,8 +270,12 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
final Size[] choices, final int width, final int height, final Size aspectRatio) {
|
final Size[] choices, final int width, final int height, final Size aspectRatio) {
|
||||||
// Collect the supported resolutions that are at least as big as the preview Surface
|
// Collect the supported resolutions that are at least as big as the preview Surface
|
||||||
final List<Size> bigEnough = new ArrayList<Size>();
|
final List<Size> bigEnough = new ArrayList<Size>();
|
||||||
|
|
||||||
|
final int minWidth = Math.max(width, MINIMUM_PREVIEW_SIZE);
|
||||||
|
final int minHeight = Math.max(height, MINIMUM_PREVIEW_SIZE);
|
||||||
|
|
||||||
for (final Size option : choices) {
|
for (final Size option : choices) {
|
||||||
if (option.getHeight() >= MINIMUM_PREVIEW_SIZE && option.getWidth() >= MINIMUM_PREVIEW_SIZE) {
|
if (option.getHeight() >= minHeight && option.getWidth() >= minWidth) {
|
||||||
LOGGER.i("Adding size: " + option.getWidth() + "x" + option.getHeight());
|
LOGGER.i("Adding size: " + option.getWidth() + "x" + option.getHeight());
|
||||||
bigEnough.add(option);
|
bigEnough.add(option);
|
||||||
} else {
|
} else {
|
||||||
@ -289,8 +295,9 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static CameraConnectionFragment newInstance(
|
public static CameraConnectionFragment newInstance(
|
||||||
final Classifier classifier, final int layout, final int inputSize) {
|
final ConnectionCallback callback,
|
||||||
return new CameraConnectionFragment(classifier, layout, inputSize);
|
final OnImageAvailableListener imageListener, final int layout, final int inputSize) {
|
||||||
|
return new CameraConnectionFragment(callback, imageListener, layout, inputSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -302,7 +309,6 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
@Override
|
@Override
|
||||||
public void onViewCreated(final View view, final Bundle savedInstanceState) {
|
public void onViewCreated(final View view, final Bundle savedInstanceState) {
|
||||||
textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
|
textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
|
||||||
resultsView = (ResultsView) view.findViewById(R.id.results);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -371,7 +377,8 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
// bus' bandwidth limitation, resulting in gorgeous previews but the storage of
|
// bus' bandwidth limitation, resulting in gorgeous previews but the storage of
|
||||||
// garbage capture data.
|
// garbage capture data.
|
||||||
previewSize =
|
previewSize =
|
||||||
chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class), width, height, largest);
|
chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class),
|
||||||
|
inputSize, inputSize, largest);
|
||||||
|
|
||||||
// We fit the aspect ratio of TextureView to the size of preview we picked.
|
// We fit the aspect ratio of TextureView to the size of preview we picked.
|
||||||
final int orientation = getResources().getConfiguration().orientation;
|
final int orientation = getResources().getConfiguration().orientation;
|
||||||
@ -382,6 +389,8 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CameraConnectionFragment.this.cameraId = cameraId;
|
CameraConnectionFragment.this.cameraId = cameraId;
|
||||||
|
|
||||||
|
cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} catch (final CameraAccessException e) {
|
} catch (final CameraAccessException e) {
|
||||||
@ -446,10 +455,6 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
backgroundThread = new HandlerThread("ImageListener");
|
backgroundThread = new HandlerThread("ImageListener");
|
||||||
backgroundThread.start();
|
backgroundThread.start();
|
||||||
backgroundHandler = new Handler(backgroundThread.getLooper());
|
backgroundHandler = new Handler(backgroundThread.getLooper());
|
||||||
|
|
||||||
inferenceThread = new HandlerThread("InferenceThread");
|
|
||||||
inferenceThread.start();
|
|
||||||
inferenceHandler = new Handler(inferenceThread.getLooper());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -457,22 +462,15 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
*/
|
*/
|
||||||
private void stopBackgroundThread() {
|
private void stopBackgroundThread() {
|
||||||
backgroundThread.quitSafely();
|
backgroundThread.quitSafely();
|
||||||
inferenceThread.quitSafely();
|
|
||||||
try {
|
try {
|
||||||
backgroundThread.join();
|
backgroundThread.join();
|
||||||
backgroundThread = null;
|
backgroundThread = null;
|
||||||
backgroundHandler = null;
|
backgroundHandler = null;
|
||||||
|
|
||||||
inferenceThread.join();
|
|
||||||
inferenceThread = null;
|
|
||||||
inferenceThread = null;
|
|
||||||
} catch (final InterruptedException e) {
|
} catch (final InterruptedException e) {
|
||||||
LOGGER.e(e, "Exception!");
|
LOGGER.e(e, "Exception!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private final TensorFlowImageListener tfPreviewListener = new TensorFlowImageListener();
|
|
||||||
|
|
||||||
private final CameraCaptureSession.CaptureCallback captureCallback =
|
private final CameraCaptureSession.CaptureCallback captureCallback =
|
||||||
new CameraCaptureSession.CaptureCallback() {
|
new CameraCaptureSession.CaptureCallback() {
|
||||||
@Override
|
@Override
|
||||||
@ -513,7 +511,7 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
ImageReader.newInstance(
|
ImageReader.newInstance(
|
||||||
previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
|
previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
|
||||||
|
|
||||||
previewReader.setOnImageAvailableListener(tfPreviewListener, backgroundHandler);
|
previewReader.setOnImageAvailableListener(imageListener, backgroundHandler);
|
||||||
previewRequestBuilder.addTarget(previewReader.getSurface());
|
previewRequestBuilder.addTarget(previewReader.getSurface());
|
||||||
|
|
||||||
// Here, we create a CameraCaptureSession for camera preview.
|
// Here, we create a CameraCaptureSession for camera preview.
|
||||||
@ -557,11 +555,6 @@ public class CameraConnectionFragment extends Fragment {
|
|||||||
} catch (final CameraAccessException e) {
|
} catch (final CameraAccessException e) {
|
||||||
LOGGER.e(e, "Exception!");
|
LOGGER.e(e, "Exception!");
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGGER.i("Getting assets.");
|
|
||||||
tfPreviewListener.initialize(
|
|
||||||
classifier, resultsView, inputSize, inferenceHandler, sensorOrientation);
|
|
||||||
LOGGER.i("TensorFlow initialized.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -16,12 +16,29 @@
|
|||||||
|
|
||||||
package org.tensorflow.demo;
|
package org.tensorflow.demo;
|
||||||
|
|
||||||
|
import android.graphics.Bitmap;
|
||||||
|
import android.graphics.Bitmap.Config;
|
||||||
|
import android.graphics.Canvas;
|
||||||
|
import android.graphics.Matrix;
|
||||||
|
import android.graphics.Paint;
|
||||||
|
import android.media.Image;
|
||||||
|
import android.media.Image.Plane;
|
||||||
|
import android.media.ImageReader;
|
||||||
|
import android.media.ImageReader.OnImageAvailableListener;
|
||||||
|
import android.os.SystemClock;
|
||||||
|
import android.os.Trace;
|
||||||
|
import android.util.Size;
|
||||||
|
import android.util.TypedValue;
|
||||||
|
import android.view.Display;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
import android.app.Fragment;
|
import java.util.Vector;
|
||||||
|
import org.tensorflow.demo.OverlayView.DrawCallback;
|
||||||
|
import org.tensorflow.demo.env.BorderedText;
|
||||||
|
import org.tensorflow.demo.env.ImageUtils;
|
||||||
import org.tensorflow.demo.env.Logger;
|
import org.tensorflow.demo.env.Logger;
|
||||||
|
|
||||||
public class ClassifierActivity extends CameraActivity {
|
public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener {
|
||||||
private static final Logger LOGGER = new Logger();
|
private static final Logger LOGGER = new Logger();
|
||||||
|
|
||||||
// These are the settings for the original v1 Inception model. If you want to
|
// These are the settings for the original v1 Inception model. If you want to
|
||||||
@ -41,9 +58,58 @@ public class ClassifierActivity extends CameraActivity {
|
|||||||
private static final String LABEL_FILE =
|
private static final String LABEL_FILE =
|
||||||
"file:///android_asset/imagenet_comp_graph_label_strings.txt";
|
"file:///android_asset/imagenet_comp_graph_label_strings.txt";
|
||||||
|
|
||||||
|
private static final boolean SAVE_PREVIEW_BITMAP = false;
|
||||||
|
|
||||||
|
private static final boolean MAINTAIN_ASPECT = true;
|
||||||
|
|
||||||
|
private TensorFlowImageClassifier classifier;
|
||||||
|
|
||||||
|
private Integer sensorOrientation;
|
||||||
|
|
||||||
|
private int previewWidth = 0;
|
||||||
|
private int previewHeight = 0;
|
||||||
|
private byte[][] yuvBytes;
|
||||||
|
private int[] rgbBytes = null;
|
||||||
|
private Bitmap rgbFrameBitmap = null;
|
||||||
|
private Bitmap croppedBitmap = null;
|
||||||
|
|
||||||
|
private Bitmap cropCopyBitmap;
|
||||||
|
|
||||||
|
private boolean computing = false;
|
||||||
|
|
||||||
|
private long timestamp = 0;
|
||||||
|
|
||||||
|
private Matrix frameToCropTransform;
|
||||||
|
private Matrix cropToFrameTransform;
|
||||||
|
|
||||||
|
private ResultsView resultsView;
|
||||||
|
|
||||||
|
private OverlayView overlayView;
|
||||||
|
|
||||||
|
private BorderedText borderedText;
|
||||||
|
|
||||||
|
private long lastProcessingTimeMs;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Fragment createFragment() {
|
protected int getLayoutId() {
|
||||||
final TensorFlowImageClassifier classifier = new TensorFlowImageClassifier();
|
return R.layout.camera_connection_fragment;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected int getDesiredPreviewFrameSize() {
|
||||||
|
return INPUT_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final float TEXT_SIZE_DIP = 18;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onPreviewSizeChosen(final Size size, final int rotation) {
|
||||||
|
final float textSizePx = TypedValue.applyDimension(
|
||||||
|
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
|
||||||
|
getResources().getDisplayMetrics());
|
||||||
|
borderedText = new BorderedText(textSizePx);
|
||||||
|
|
||||||
|
classifier = new TensorFlowImageClassifier();
|
||||||
try {
|
try {
|
||||||
classifier.initializeTensorFlow(
|
classifier.initializeTensorFlow(
|
||||||
getAssets(), MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD,
|
getAssets(), MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD,
|
||||||
@ -52,7 +118,151 @@ public class ClassifierActivity extends CameraActivity {
|
|||||||
LOGGER.e(e, "Exception!");
|
LOGGER.e(e, "Exception!");
|
||||||
}
|
}
|
||||||
|
|
||||||
return CameraConnectionFragment.newInstance(
|
overlayView = (OverlayView) findViewById(R.id.overlay);
|
||||||
classifier, R.layout.camera_connection_fragment, INPUT_SIZE);
|
resultsView = (ResultsView) findViewById(R.id.results);
|
||||||
|
previewWidth = size.getWidth();
|
||||||
|
previewHeight = size.getHeight();
|
||||||
|
|
||||||
|
final Display display = getWindowManager().getDefaultDisplay();
|
||||||
|
final int screenOrientation = display.getRotation();
|
||||||
|
|
||||||
|
LOGGER.i("Sensor orientation: %d, Screen orientation: %d",
|
||||||
|
rotation, screenOrientation);
|
||||||
|
|
||||||
|
sensorOrientation = rotation + screenOrientation;
|
||||||
|
|
||||||
|
if (sensorOrientation % 180 == 90) {
|
||||||
|
overlayView.setAspectRatio(size.getHeight(), size.getWidth());
|
||||||
|
} else {
|
||||||
|
overlayView.setAspectRatio(size.getWidth(), size.getHeight());
|
||||||
|
}
|
||||||
|
|
||||||
|
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
|
||||||
|
rgbBytes = new int[previewWidth * previewHeight];
|
||||||
|
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
|
||||||
|
croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
|
||||||
|
|
||||||
|
frameToCropTransform = ImageUtils.getTransformationMatrix(
|
||||||
|
previewWidth, previewHeight,
|
||||||
|
INPUT_SIZE, INPUT_SIZE,
|
||||||
|
sensorOrientation, MAINTAIN_ASPECT);
|
||||||
|
|
||||||
|
cropToFrameTransform = new Matrix();
|
||||||
|
frameToCropTransform.invert(cropToFrameTransform);
|
||||||
|
|
||||||
|
yuvBytes = new byte[3][];
|
||||||
|
|
||||||
|
overlayView.addCallback(new DrawCallback() {
|
||||||
|
@Override
|
||||||
|
public void drawCallback(final Canvas canvas) {
|
||||||
|
renderDebug(canvas);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onImageAvailable(final ImageReader reader) {
|
||||||
|
Image image = null;
|
||||||
|
|
||||||
|
++timestamp;
|
||||||
|
|
||||||
|
try {
|
||||||
|
image = reader.acquireLatestImage();
|
||||||
|
|
||||||
|
if (image == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (computing) {
|
||||||
|
image.close();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
computing = true;
|
||||||
|
|
||||||
|
Trace.beginSection("imageAvailable");
|
||||||
|
|
||||||
|
final Plane[] planes = image.getPlanes();
|
||||||
|
fillBytes(planes, yuvBytes);
|
||||||
|
|
||||||
|
final int yRowStride = planes[0].getRowStride();
|
||||||
|
final int uvRowStride = planes[1].getRowStride();
|
||||||
|
final int uvPixelStride = planes[1].getPixelStride();
|
||||||
|
ImageUtils.convertYUV420ToARGB8888(
|
||||||
|
yuvBytes[0],
|
||||||
|
yuvBytes[1],
|
||||||
|
yuvBytes[2],
|
||||||
|
rgbBytes,
|
||||||
|
previewWidth,
|
||||||
|
previewHeight,
|
||||||
|
yRowStride,
|
||||||
|
uvRowStride,
|
||||||
|
uvPixelStride,
|
||||||
|
false);
|
||||||
|
|
||||||
|
image.close();
|
||||||
|
} catch (final Exception e) {
|
||||||
|
if (image != null) {
|
||||||
|
image.close();
|
||||||
|
}
|
||||||
|
LOGGER.e(e, "Exception!");
|
||||||
|
Trace.endSection();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
|
||||||
|
final Canvas canvas = new Canvas(croppedBitmap);
|
||||||
|
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
|
||||||
|
|
||||||
|
// For examining the actual TF input.
|
||||||
|
if (SAVE_PREVIEW_BITMAP) {
|
||||||
|
ImageUtils.saveBitmap(croppedBitmap);
|
||||||
|
}
|
||||||
|
|
||||||
|
runInBackground(
|
||||||
|
new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
final long startTime = SystemClock.uptimeMillis();
|
||||||
|
final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
|
||||||
|
lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
|
||||||
|
|
||||||
|
cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
|
||||||
|
resultsView.setResults(results);
|
||||||
|
overlayView.postInvalidate();
|
||||||
|
computing = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Trace.endSection();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void renderDebug(Canvas canvas) {
|
||||||
|
if (!isDebug()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
final Bitmap copy = cropCopyBitmap;
|
||||||
|
if (copy != null) {
|
||||||
|
final Matrix matrix = new Matrix();
|
||||||
|
final float scaleFactor = 2;
|
||||||
|
matrix.postScale(scaleFactor, scaleFactor);
|
||||||
|
matrix.postTranslate(
|
||||||
|
canvas.getWidth() - copy.getWidth() * scaleFactor,
|
||||||
|
canvas.getHeight() - copy.getHeight() * scaleFactor);
|
||||||
|
canvas.drawBitmap(copy, matrix, new Paint());
|
||||||
|
|
||||||
|
final Vector<String> lines = new Vector<String>();
|
||||||
|
lines.add("Frame: " + previewWidth + "x" + previewHeight);
|
||||||
|
lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
|
||||||
|
lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
|
||||||
|
lines.add("Rotation: " + sensorOrientation);
|
||||||
|
lines.add("Inference time: " + lastProcessingTimeMs + "ms");
|
||||||
|
|
||||||
|
int lineNum = 0;
|
||||||
|
for (final String line : lines) {
|
||||||
|
borderedText.drawText(canvas, 10,
|
||||||
|
canvas.getHeight() - 10 - borderedText.getTextSize() * lineNum, line);
|
||||||
|
++lineNum;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,94 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.demo;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.graphics.Canvas;
|
||||||
|
import android.util.AttributeSet;
|
||||||
|
import android.view.MotionEvent;
|
||||||
|
import android.view.View;
|
||||||
|
import android.view.View.MeasureSpec;
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple View providing a render callback to other classes.
|
||||||
|
*/
|
||||||
|
public class OverlayView extends View {
|
||||||
|
public OverlayView(final Context context, final AttributeSet attrs) {
|
||||||
|
super(context, attrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface defining the callback for client classes.
|
||||||
|
*/
|
||||||
|
public interface DrawCallback {
|
||||||
|
public void drawCallback(final Canvas canvas);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int ratioWidth;
|
||||||
|
private int ratioHeight;
|
||||||
|
|
||||||
|
private boolean debug;
|
||||||
|
|
||||||
|
private final List<DrawCallback> callbacks = new LinkedList<DrawCallback>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean onTouchEvent(final MotionEvent e) {
|
||||||
|
super.onTouchEvent(e);
|
||||||
|
if (e.getAction() == MotionEvent.ACTION_DOWN) {
|
||||||
|
debug = !debug;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addCallback(final DrawCallback callback) {
|
||||||
|
callbacks.add(callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized void draw(final Canvas canvas) {
|
||||||
|
for (final DrawCallback callback : callbacks) {
|
||||||
|
callback.drawCallback(canvas);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setAspectRatio(final int width, final int height) {
|
||||||
|
if (width < 0 || height < 0) {
|
||||||
|
throw new IllegalArgumentException("Size cannot be negative.");
|
||||||
|
}
|
||||||
|
ratioWidth = width;
|
||||||
|
ratioHeight = height;
|
||||||
|
requestLayout();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
|
||||||
|
super.onMeasure(widthMeasureSpec, heightMeasureSpec);
|
||||||
|
final int width = MeasureSpec.getSize(widthMeasureSpec);
|
||||||
|
final int height = MeasureSpec.getSize(heightMeasureSpec);
|
||||||
|
if (0 == ratioWidth || 0 == ratioHeight) {
|
||||||
|
setMeasuredDimension(width, height);
|
||||||
|
} else {
|
||||||
|
if (width < height * ratioWidth / ratioHeight) {
|
||||||
|
setMeasuredDimension(width, width * ratioHeight / ratioWidth);
|
||||||
|
} else {
|
||||||
|
setMeasuredDimension(height * ratioWidth / ratioHeight, height);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
119
tensorflow/examples/android/src/org/tensorflow/demo/env/BorderedText.java
vendored
Normal file
119
tensorflow/examples/android/src/org/tensorflow/demo/env/BorderedText.java
vendored
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.demo.env;
|
||||||
|
|
||||||
|
import android.graphics.Canvas;
|
||||||
|
import android.graphics.Color;
|
||||||
|
import android.graphics.Paint;
|
||||||
|
import android.graphics.Paint.Align;
|
||||||
|
import android.graphics.Paint.Style;
|
||||||
|
import android.graphics.Rect;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas.
|
||||||
|
*/
|
||||||
|
public class BorderedText {
|
||||||
|
private final Paint interiorPaint;
|
||||||
|
private final Paint exteriorPaint;
|
||||||
|
|
||||||
|
private final float textSize;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a left-aligned bordered text object with a white interior, and a black exterior with
|
||||||
|
* the specified text size.
|
||||||
|
*
|
||||||
|
* @param textSize text size in pixels
|
||||||
|
*/
|
||||||
|
public BorderedText(final float textSize) {
|
||||||
|
this(Color.WHITE, Color.BLACK, textSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a bordered text object with the specified interior and exterior colors, text size and
|
||||||
|
* alignment.
|
||||||
|
*
|
||||||
|
* @param interiorColor the interior text color
|
||||||
|
* @param exteriorColor the exterior text color
|
||||||
|
* @param textSize text size in pixels
|
||||||
|
*/
|
||||||
|
public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) {
|
||||||
|
interiorPaint = new Paint();
|
||||||
|
interiorPaint.setTextSize(textSize);
|
||||||
|
interiorPaint.setColor(interiorColor);
|
||||||
|
interiorPaint.setStyle(Style.FILL);
|
||||||
|
interiorPaint.setAntiAlias(false);
|
||||||
|
interiorPaint.setAlpha(255);
|
||||||
|
|
||||||
|
exteriorPaint = new Paint();
|
||||||
|
exteriorPaint.setTextSize(textSize);
|
||||||
|
exteriorPaint.setColor(exteriorColor);
|
||||||
|
exteriorPaint.setStyle(Style.FILL_AND_STROKE);
|
||||||
|
exteriorPaint.setStrokeWidth(textSize / 8);
|
||||||
|
exteriorPaint.setAntiAlias(false);
|
||||||
|
exteriorPaint.setAlpha(255);
|
||||||
|
|
||||||
|
this.textSize = textSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void drawText(final Canvas canvas, final float posX, final float posY, final String text) {
|
||||||
|
/*
|
||||||
|
if (widths == null || widths.length < text.length()) {
|
||||||
|
widths = new float[text.length()];
|
||||||
|
positions = new float[text.length() * 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
exteriorPaint.getTextWidths(text, widths);
|
||||||
|
float lastPosX = posX;
|
||||||
|
for (int i = 0; i < widths.length; ++i) {
|
||||||
|
positions[i * 2] = lastPosX;
|
||||||
|
positions[i * 2 + 1] = posY;
|
||||||
|
lastPosX += widths[i];
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
//canvas.drawPosText(text, positions, exteriorPaint);
|
||||||
|
//canvas.drawPosText(text, positions, exteriorPaint);
|
||||||
|
canvas.drawText(text, posX, posY, exteriorPaint);
|
||||||
|
canvas.drawText(text, posX, posY, interiorPaint);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setInteriorColor(final int color) {
|
||||||
|
interiorPaint.setColor(color);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setExteriorColor(final int color) {
|
||||||
|
exteriorPaint.setColor(color);
|
||||||
|
}
|
||||||
|
|
||||||
|
public float getTextSize() {
|
||||||
|
return textSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setAlpha(final int alpha) {
|
||||||
|
interiorPaint.setAlpha(alpha);
|
||||||
|
exteriorPaint.setAlpha(alpha);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void getTextBounds(
|
||||||
|
final String line, final int index, final int count, final Rect lineBounds) {
|
||||||
|
interiorPaint.getTextBounds(line, index, count, lineBounds);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTextAlign(final Align align) {
|
||||||
|
interiorPaint.setTextAlign(align);
|
||||||
|
exteriorPaint.setTextAlign(align);
|
||||||
|
}
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user