Merge pull request #4220 from caisq/branch_132225803
Merge Changes from Internal: Branch 132225803
This commit is contained in:
commit
7a45bc5e7f
@ -29,6 +29,15 @@ config_setting(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
config_setting(
|
||||||
|
name = "android_arm64",
|
||||||
|
values = {
|
||||||
|
"crosstool_top": "//external:android/crosstool",
|
||||||
|
"android_cpu": "arm64-v8a",
|
||||||
|
},
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "darwin",
|
name = "darwin",
|
||||||
values = {"cpu": "darwin"},
|
values = {"cpu": "darwin"},
|
||||||
@ -95,6 +104,7 @@ filegroup(
|
|||||||
"//tensorflow/contrib/ffmpeg/default:all_files",
|
"//tensorflow/contrib/ffmpeg/default:all_files",
|
||||||
"//tensorflow/contrib/framework:all_files",
|
"//tensorflow/contrib/framework:all_files",
|
||||||
"//tensorflow/contrib/graph_editor:all_files",
|
"//tensorflow/contrib/graph_editor:all_files",
|
||||||
|
"//tensorflow/contrib/grid_rnn:all_files",
|
||||||
"//tensorflow/contrib/layers:all_files",
|
"//tensorflow/contrib/layers:all_files",
|
||||||
"//tensorflow/contrib/layers/kernels:all_files",
|
"//tensorflow/contrib/layers/kernels:all_files",
|
||||||
"//tensorflow/contrib/learn:all_files",
|
"//tensorflow/contrib/learn:all_files",
|
||||||
|
@ -87,7 +87,7 @@ TEST(CApi, AllocateTensor) {
|
|||||||
static void TestEncodeDecode(int line,
|
static void TestEncodeDecode(int line,
|
||||||
const std::vector<tensorflow::string>& data) {
|
const std::vector<tensorflow::string>& data) {
|
||||||
const tensorflow::int64 n = data.size();
|
const tensorflow::int64 n = data.size();
|
||||||
for (std::vector<tensorflow::int64> dims :
|
for (const std::vector<tensorflow::int64>& dims :
|
||||||
std::vector<std::vector<tensorflow::int64>>{
|
std::vector<std::vector<tensorflow::int64>>{
|
||||||
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
|
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
|
||||||
// Create C++ Tensor
|
// Create C++ Tensor
|
||||||
|
@ -37,7 +37,7 @@ namespace tensorflow {
|
|||||||
namespace example {
|
namespace example {
|
||||||
|
|
||||||
struct Options {
|
struct Options {
|
||||||
int num_concurrent_sessions = 10; // The number of concurrent sessions
|
int num_concurrent_sessions = 1; // The number of concurrent sessions
|
||||||
int num_concurrent_steps = 10; // The number of concurrent steps
|
int num_concurrent_steps = 10; // The number of concurrent steps
|
||||||
int num_iterations = 100; // Each step repeats this many times
|
int num_iterations = 100; // Each step repeats this many times
|
||||||
bool use_gpu = false; // Whether to use gpu in the training
|
bool use_gpu = false; // Whether to use gpu in the training
|
||||||
@ -108,10 +108,11 @@ void ConcurrentSteps(const Options* opts, int session_index) {
|
|||||||
|
|
||||||
// Spawn M threads for M concurrent steps.
|
// Spawn M threads for M concurrent steps.
|
||||||
const int M = opts->num_concurrent_steps;
|
const int M = opts->num_concurrent_steps;
|
||||||
thread::ThreadPool step_threads(Env::Default(), "trainer", M);
|
std::unique_ptr<thread::ThreadPool> step_threads(
|
||||||
|
new thread::ThreadPool(Env::Default(), "trainer", M));
|
||||||
|
|
||||||
for (int step = 0; step < M; ++step) {
|
for (int step = 0; step < M; ++step) {
|
||||||
step_threads.Schedule([&session, opts, session_index, step]() {
|
step_threads->Schedule([&session, opts, session_index, step]() {
|
||||||
// Randomly initialize the input.
|
// Randomly initialize the input.
|
||||||
Tensor x(DT_FLOAT, TensorShape({2, 1}));
|
Tensor x(DT_FLOAT, TensorShape({2, 1}));
|
||||||
auto x_flat = x.flat<float>();
|
auto x_flat = x.flat<float>();
|
||||||
@ -139,12 +140,19 @@ void ConcurrentSteps(const Options* opts, int session_index) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete the threadpool, thus waiting for all threads to complete.
|
||||||
|
step_threads.reset(nullptr);
|
||||||
TF_CHECK_OK(session->Close());
|
TF_CHECK_OK(session->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConcurrentSessions(const Options& opts) {
|
void ConcurrentSessions(const Options& opts) {
|
||||||
// Spawn N threads for N concurrent sessions.
|
// Spawn N threads for N concurrent sessions.
|
||||||
const int N = opts.num_concurrent_sessions;
|
const int N = opts.num_concurrent_sessions;
|
||||||
|
|
||||||
|
// At the moment our Session implementation only allows
|
||||||
|
// one concurrently computing Session on GPU.
|
||||||
|
CHECK_EQ(1, N) << "Currently can only have one concurrent session.";
|
||||||
|
|
||||||
thread::ThreadPool session_threads(Env::Default(), "trainer", N);
|
thread::ThreadPool session_threads(Env::Default(), "trainer", N);
|
||||||
for (int i = 0; i < N; ++i) {
|
for (int i = 0; i < N; ++i) {
|
||||||
session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i));
|
session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i));
|
||||||
|
@ -23,6 +23,7 @@ cuda_py_test(
|
|||||||
srcs = ["python/kernel_tests/entropy_test.py"],
|
srcs = ["python/kernel_tests/entropy_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":bayesflow_py",
|
":bayesflow_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -34,6 +35,7 @@ cuda_py_test(
|
|||||||
srcs = ["python/kernel_tests/monte_carlo_test.py"],
|
srcs = ["python/kernel_tests/monte_carlo_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":bayesflow_py",
|
":bayesflow_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -45,6 +47,7 @@ cuda_py_test(
|
|||||||
srcs = ["python/kernel_tests/special_math_test.py"],
|
srcs = ["python/kernel_tests/special_math_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":bayesflow_py",
|
":bayesflow_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -56,6 +59,7 @@ cuda_py_test(
|
|||||||
srcs = ["python/kernel_tests/stochastic_graph_test.py"],
|
srcs = ["python/kernel_tests/stochastic_graph_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":bayesflow_py",
|
":bayesflow_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -67,6 +71,7 @@ cuda_py_test(
|
|||||||
srcs = ["python/kernel_tests/variational_inference_test.py"],
|
srcs = ["python/kernel_tests/variational_inference_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":bayesflow_py",
|
":bayesflow_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -78,6 +83,7 @@ cuda_py_test(
|
|||||||
srcs = ["python/kernel_tests/stochastic_tensor_test.py"],
|
srcs = ["python/kernel_tests/stochastic_tensor_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":bayesflow_py",
|
":bayesflow_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -89,6 +95,7 @@ cuda_py_test(
|
|||||||
srcs = ["examples/reinforce_simple/reinforce_simple_example.py"],
|
srcs = ["examples/reinforce_simple/reinforce_simple_example.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":bayesflow_py",
|
":bayesflow_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
@ -159,6 +159,21 @@ class NdtrGradientTest(tf.test.TestCase):
|
|||||||
_use_log = False
|
_use_log = False
|
||||||
_grid = GridSpec(min=-100., max=100., shape=[1, 2, 3, 8])
|
_grid = GridSpec(min=-100., max=100., shape=[1, 2, 3, 8])
|
||||||
|
|
||||||
|
def assert_all_true(self, v):
|
||||||
|
self.assertAllEqual(np.ones_like(v, dtype=np.bool), v)
|
||||||
|
|
||||||
|
def assert_all_false(self, v):
|
||||||
|
self.assertAllEqual(np.zeros_like(v, dtype=np.bool), v)
|
||||||
|
|
||||||
|
def _test_grad_finite(self, dtype):
|
||||||
|
with self.test_session():
|
||||||
|
x = tf.Variable([-100., 0., 100.], dtype=dtype)
|
||||||
|
output = (sm.log_ndtr(x) if self._use_log else sm.ndtr(x))
|
||||||
|
grad_output = tf.gradients(output, x)
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
self.assert_all_true(np.isfinite(output.eval()))
|
||||||
|
self.assert_all_true(np.isfinite(grad_output[0].eval()))
|
||||||
|
|
||||||
def _test_grads_are_positive(self, dtype, grid_spec):
|
def _test_grads_are_positive(self, dtype, grid_spec):
|
||||||
grid = tf.convert_to_tensor(_make_grid(dtype, grid_spec))
|
grid = tf.convert_to_tensor(_make_grid(dtype, grid_spec))
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -169,20 +184,24 @@ class NdtrGradientTest(tf.test.TestCase):
|
|||||||
# grad_eval.shape = (N, N), with grad_eval[i, j] the partial derivative of
|
# grad_eval.shape = (N, N), with grad_eval[i, j] the partial derivative of
|
||||||
# the ith output point w.r.t. the jth grid point. We only expect the
|
# the ith output point w.r.t. the jth grid point. We only expect the
|
||||||
# diagonal to be nonzero.
|
# diagonal to be nonzero.
|
||||||
|
# TODO(b/31131137): Replace tf.test.compute_gradient with our own custom
|
||||||
|
# gradient evaluation to ensure we correctly handle small function delta.
|
||||||
grad_eval, _ = tf.test.compute_gradient(
|
grad_eval, _ = tf.test.compute_gradient(
|
||||||
grid, grid_spec.shape, output, grid_spec.shape)
|
grid, grid_spec.shape, output, grid_spec.shape)
|
||||||
grad_eval = np.diag(grad_eval)
|
grad_eval = np.diag(grad_eval)
|
||||||
|
|
||||||
# Check for NaN separately in order to get informative failures.
|
# Check for NaN separately in order to get informative failures.
|
||||||
self.assertFalse(np.isnan(grad_eval).any())
|
self.assert_all_false(np.isnan(grad_eval))
|
||||||
self.assertTrue((grad_eval > 0).all())
|
self.assert_all_true(grad_eval > 0.)
|
||||||
self.assertTrue(np.isfinite(grad_eval).all())
|
self.assert_all_true(np.isfinite(grad_eval))
|
||||||
|
|
||||||
def test_float32(self):
|
def test_float32(self):
|
||||||
self._test_grads_are_positive(np.float32, self._grid)
|
self._test_grads_are_positive(np.float32, self._grid)
|
||||||
|
self._test_grad_finite(np.float32)
|
||||||
|
|
||||||
def test_float64(self):
|
def test_float64(self):
|
||||||
self._test_grads_are_positive(np.float64, self._grid)
|
self._test_grads_are_positive(np.float64, self._grid)
|
||||||
|
self._test_grad_finite(np.float64)
|
||||||
|
|
||||||
|
|
||||||
class LogNdtrGradientTest(NdtrGradientTest):
|
class LogNdtrGradientTest(NdtrGradientTest):
|
||||||
|
@ -174,13 +174,18 @@ def log_ndtr(x, series_order=3, name=None):
|
|||||||
# * We use one fixed series_order for all of 'x', rather than adaptive.
|
# * We use one fixed series_order for all of 'x', rather than adaptive.
|
||||||
# * Our docstring properly reflects that this is an asymptotic series, not a
|
# * Our docstring properly reflects that this is an asymptotic series, not a
|
||||||
# Tayor series. We also provided a correct bound on the remainder.
|
# Tayor series. We also provided a correct bound on the remainder.
|
||||||
|
# * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
|
||||||
|
# x=0. This happens even though the branch is unchosen because when x=0
|
||||||
|
# the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
|
||||||
|
# regardless of whether dy is finite. Note that the minimum is a NOP if
|
||||||
|
# the branch is chosen.
|
||||||
return math_ops.select(
|
return math_ops.select(
|
||||||
math_ops.greater(x, upper_segment),
|
math_ops.greater(x, upper_segment),
|
||||||
-_ndtr(-x), # log(1-x) ~= -x, x << 1
|
-_ndtr(-x), # log(1-x) ~= -x, x << 1
|
||||||
math_ops.select(math_ops.greater(x, lower_segment),
|
math_ops.select(math_ops.greater(x, lower_segment),
|
||||||
math_ops.log(_ndtr(x)),
|
math_ops.log(_ndtr(math_ops.maximum(x, lower_segment))),
|
||||||
_log_ndtr_lower(x, series_order)))
|
_log_ndtr_lower(math_ops.minimum(x, lower_segment),
|
||||||
|
series_order)))
|
||||||
|
|
||||||
|
|
||||||
def _log_ndtr_lower(x, series_order):
|
def _log_ndtr_lower(x, series_order):
|
||||||
|
@ -16,6 +16,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/operator_pd_test.py"],
|
srcs = ["python/kernel_tests/operator_pd_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -27,6 +28,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/operator_pd_cholesky_test.py"],
|
srcs = ["python/kernel_tests/operator_pd_cholesky_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -38,6 +40,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/operator_pd_diag_test.py"],
|
srcs = ["python/kernel_tests/operator_pd_diag_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -49,6 +52,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/operator_pd_full_test.py"],
|
srcs = ["python/kernel_tests/operator_pd_full_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -60,6 +64,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/operator_pd_identity_test.py"],
|
srcs = ["python/kernel_tests/operator_pd_identity_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -71,6 +76,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/operator_pd_vdvt_update_test.py"],
|
srcs = ["python/kernel_tests/operator_pd_vdvt_update_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -89,6 +95,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/bernoulli_test.py"],
|
srcs = ["python/kernel_tests/bernoulli_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -99,6 +106,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/beta_test.py"],
|
srcs = ["python/kernel_tests/beta_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
tags = ["notsan"], #http://b/31216497
|
tags = ["notsan"], #http://b/31216497
|
||||||
@ -110,6 +118,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/binomial_test.py"],
|
srcs = ["python/kernel_tests/binomial_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -120,6 +129,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/categorical_test.py"],
|
srcs = ["python/kernel_tests/categorical_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -129,6 +139,7 @@ cuda_py_tests(
|
|||||||
name = "chi2_test",
|
name = "chi2_test",
|
||||||
srcs = ["python/kernel_tests/chi2_test.py"],
|
srcs = ["python/kernel_tests/chi2_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -140,6 +151,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/dirichlet_test.py"],
|
srcs = ["python/kernel_tests/dirichlet_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -151,6 +163,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"],
|
srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -161,6 +174,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/exponential_test.py"],
|
srcs = ["python/kernel_tests/exponential_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -170,6 +184,7 @@ cuda_py_tests(
|
|||||||
name = "gamma_test",
|
name = "gamma_test",
|
||||||
srcs = ["python/kernel_tests/gamma_test.py"],
|
srcs = ["python/kernel_tests/gamma_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -180,6 +195,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/inverse_gamma_test.py"],
|
srcs = ["python/kernel_tests/inverse_gamma_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -190,6 +206,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/laplace_test.py"],
|
srcs = ["python/kernel_tests/laplace_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -200,6 +217,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/multinomial_test.py"],
|
srcs = ["python/kernel_tests/multinomial_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -211,6 +229,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/mvn_test.py"],
|
srcs = ["python/kernel_tests/mvn_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -222,6 +241,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/mixture_test.py"],
|
srcs = ["python/kernel_tests/mixture_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -233,6 +253,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/normal_test.py"],
|
srcs = ["python/kernel_tests/normal_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -244,6 +265,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/poisson_test.py"],
|
srcs = ["python/kernel_tests/poisson_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -255,6 +277,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/student_t_test.py"],
|
srcs = ["python/kernel_tests/student_t_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -266,6 +289,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/uniform_test.py"],
|
srcs = ["python/kernel_tests/uniform_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -277,6 +301,7 @@ cuda_py_tests(
|
|||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -286,6 +311,7 @@ cuda_py_tests(
|
|||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["python/kernel_tests/kullback_leibler_test.py"],
|
srcs = ["python/kernel_tests/kullback_leibler_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -296,6 +322,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/normal_conjugate_posteriors_test.py"],
|
srcs = ["python/kernel_tests/normal_conjugate_posteriors_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -306,6 +333,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/transformed_distribution_test.py"],
|
srcs = ["python/kernel_tests/transformed_distribution_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -316,6 +344,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/distribution_util_test.py"],
|
srcs = ["python/kernel_tests/distribution_util_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -327,6 +356,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/shape_test.py"],
|
srcs = ["python/kernel_tests/shape_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -338,6 +368,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/bijector_test.py"],
|
srcs = ["python/kernel_tests/bijector_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
@ -27,6 +27,14 @@ import tensorflow as tf
|
|||||||
|
|
||||||
class NormalTest(tf.test.TestCase):
|
class NormalTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._rng = np.random.RandomState(123)
|
||||||
|
|
||||||
|
def assertAllFinite(self, tensor):
|
||||||
|
is_finite = np.isfinite(tensor.eval())
|
||||||
|
all_true = np.ones_like(is_finite, dtype=np.bool)
|
||||||
|
self.assertAllEqual(all_true, is_finite)
|
||||||
|
|
||||||
def _testParamShapes(self, sample_shape, expected):
|
def _testParamShapes(self, sample_shape, expected):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
param_shapes = tf.contrib.distributions.Normal.param_shapes(sample_shape)
|
param_shapes = tf.contrib.distributions.Normal.param_shapes(sample_shape)
|
||||||
@ -143,21 +151,94 @@ class NormalTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testNormalCDF(self):
|
def testNormalCDF(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
batch_size = 6
|
batch_size = 50
|
||||||
mu = tf.constant([3.0] * batch_size)
|
mu = self._rng.randn(batch_size)
|
||||||
sigma = tf.constant([math.sqrt(10.0)] * batch_size)
|
sigma = self._rng.rand(batch_size) + 1.0
|
||||||
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
|
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
|
||||||
|
|
||||||
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||||
expected_cdf = stats.norm(mu.eval(), sigma.eval()).cdf(x)
|
expected_cdf = stats.norm(mu, sigma).cdf(x)
|
||||||
|
|
||||||
cdf = normal.cdf(x)
|
cdf = normal.cdf(x)
|
||||||
self.assertAllClose(expected_cdf, cdf.eval())
|
self.assertAllClose(expected_cdf, cdf.eval(), atol=0)
|
||||||
self.assertAllEqual(normal.batch_shape().eval(), cdf.get_shape())
|
self.assertAllEqual(normal.batch_shape().eval(), cdf.get_shape())
|
||||||
self.assertAllEqual(normal.batch_shape().eval(), cdf.eval().shape)
|
self.assertAllEqual(normal.batch_shape().eval(), cdf.eval().shape)
|
||||||
self.assertAllEqual(normal.get_batch_shape(), cdf.get_shape())
|
self.assertAllEqual(normal.get_batch_shape(), cdf.get_shape())
|
||||||
self.assertAllEqual(normal.get_batch_shape(), cdf.eval().shape)
|
self.assertAllEqual(normal.get_batch_shape(), cdf.eval().shape)
|
||||||
|
|
||||||
|
def testNormalSurvivalFunction(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 50
|
||||||
|
mu = self._rng.randn(batch_size)
|
||||||
|
sigma = self._rng.rand(batch_size) + 1.0
|
||||||
|
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
|
||||||
|
|
||||||
|
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||||
|
expected_sf = stats.norm(mu, sigma).sf(x)
|
||||||
|
|
||||||
|
sf = normal.survival_function(x)
|
||||||
|
self.assertAllClose(expected_sf, sf.eval(), atol=0)
|
||||||
|
self.assertAllEqual(normal.batch_shape().eval(), sf.get_shape())
|
||||||
|
self.assertAllEqual(normal.batch_shape().eval(), sf.eval().shape)
|
||||||
|
self.assertAllEqual(normal.get_batch_shape(), sf.get_shape())
|
||||||
|
self.assertAllEqual(normal.get_batch_shape(), sf.eval().shape)
|
||||||
|
|
||||||
|
def testNormalLogCDF(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 50
|
||||||
|
mu = self._rng.randn(batch_size)
|
||||||
|
sigma = self._rng.rand(batch_size) + 1.0
|
||||||
|
x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
|
||||||
|
|
||||||
|
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||||
|
expected_cdf = stats.norm(mu, sigma).logcdf(x)
|
||||||
|
|
||||||
|
cdf = normal.log_cdf(x)
|
||||||
|
self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5)
|
||||||
|
self.assertAllEqual(normal.batch_shape().eval(), cdf.get_shape())
|
||||||
|
self.assertAllEqual(normal.batch_shape().eval(), cdf.eval().shape)
|
||||||
|
self.assertAllEqual(normal.get_batch_shape(), cdf.get_shape())
|
||||||
|
self.assertAllEqual(normal.get_batch_shape(), cdf.eval().shape)
|
||||||
|
|
||||||
|
def testFiniteGradientAtDifficultPoints(self):
|
||||||
|
with self.test_session():
|
||||||
|
for dtype in [np.float32, np.float64]:
|
||||||
|
mu = tf.Variable(dtype(0.0))
|
||||||
|
sigma = tf.Variable(dtype(1.0))
|
||||||
|
dist = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
for func in [
|
||||||
|
dist.cdf,
|
||||||
|
dist.log_cdf,
|
||||||
|
dist.survival_function,
|
||||||
|
dist.log_survival_function,
|
||||||
|
dist.log_prob,
|
||||||
|
dist.prob]:
|
||||||
|
x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype)
|
||||||
|
value = func(x)
|
||||||
|
grads = tf.gradients(value, [mu, sigma])
|
||||||
|
|
||||||
|
self.assertAllFinite(value)
|
||||||
|
self.assertAllFinite(grads[0])
|
||||||
|
self.assertAllFinite(grads[1])
|
||||||
|
|
||||||
|
def testNormalLogSurvivalFunction(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 50
|
||||||
|
mu = self._rng.randn(batch_size)
|
||||||
|
sigma = self._rng.rand(batch_size) + 1.0
|
||||||
|
x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
|
||||||
|
|
||||||
|
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||||
|
expected_sf = stats.norm(mu, sigma).logsf(x)
|
||||||
|
|
||||||
|
sf = normal.log_survival_function(x)
|
||||||
|
self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5)
|
||||||
|
self.assertAllEqual(normal.batch_shape().eval(), sf.get_shape())
|
||||||
|
self.assertAllEqual(normal.batch_shape().eval(), sf.eval().shape)
|
||||||
|
self.assertAllEqual(normal.get_batch_shape(), sf.get_shape())
|
||||||
|
self.assertAllEqual(normal.get_batch_shape(), sf.eval().shape)
|
||||||
|
|
||||||
def testNormalEntropyWithScalarInputs(self):
|
def testNormalEntropyWithScalarInputs(self):
|
||||||
# Scipy.stats.norm cannot deal with the shapes in the other test.
|
# Scipy.stats.norm cannot deal with the shapes in the other test.
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
@ -540,6 +540,16 @@ class Distribution(BaseDistribution):
|
|||||||
def log_cdf(self, value, name="log_cdf"):
|
def log_cdf(self, value, name="log_cdf"):
|
||||||
"""Log cumulative distribution function.
|
"""Log cumulative distribution function.
|
||||||
|
|
||||||
|
Given random variable `X`, the cumulative distribution function `cdf` is:
|
||||||
|
|
||||||
|
```
|
||||||
|
log_cdf(x) := Log[ P[X <= x] ]
|
||||||
|
```
|
||||||
|
|
||||||
|
Often, a numerical approximation can be used for `log_cdf(x)` that yields
|
||||||
|
a more accurate answer than simply taking the logarithm of the `cdf` when
|
||||||
|
`x << -1`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value: `float` or `double` `Tensor`.
|
value: `float` or `double` `Tensor`.
|
||||||
name: The name to give this op.
|
name: The name to give this op.
|
||||||
@ -556,6 +566,12 @@ class Distribution(BaseDistribution):
|
|||||||
def cdf(self, value, name="cdf"):
|
def cdf(self, value, name="cdf"):
|
||||||
"""Cumulative distribution function.
|
"""Cumulative distribution function.
|
||||||
|
|
||||||
|
Given random variable `X`, the cumulative distribution function `cdf` is:
|
||||||
|
|
||||||
|
```
|
||||||
|
cdf(x) := P[X <= x]
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value: `float` or `double` `Tensor`.
|
value: `float` or `double` `Tensor`.
|
||||||
name: The name to give this op.
|
name: The name to give this op.
|
||||||
@ -569,6 +585,57 @@ class Distribution(BaseDistribution):
|
|||||||
value = ops.convert_to_tensor(value, name="value")
|
value = ops.convert_to_tensor(value, name="value")
|
||||||
return self._cdf(value)
|
return self._cdf(value)
|
||||||
|
|
||||||
|
def log_survival_function(self, value, name="log_survival_function"):
|
||||||
|
"""Log survival function.
|
||||||
|
|
||||||
|
Given random variable `X`, the survival function is defined:
|
||||||
|
|
||||||
|
```
|
||||||
|
log_survival_function(x) = Log[ P[X > x] ]
|
||||||
|
= Log[ 1 - P[X <= x] ]
|
||||||
|
= Log[ 1 - cdf(x) ]
|
||||||
|
```
|
||||||
|
|
||||||
|
Typically, different numerical approximations can be used for the log
|
||||||
|
survival function, which are more accurate than `Log[ 1 - cdf(x) ]` when `x >> 1`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: `float` or `double` `Tensor`.
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
|
||||||
|
`self.dtype`.
|
||||||
|
"""
|
||||||
|
self._check_hasattr(self._log_survival_function)
|
||||||
|
with self._name_scope(name, values=[value]):
|
||||||
|
value = ops.convert_to_tensor(value, name="value")
|
||||||
|
return self._log_survival_function(value)
|
||||||
|
|
||||||
|
def survival_function(self, value, name="survival_function"):
|
||||||
|
"""Survival function.
|
||||||
|
|
||||||
|
Given random variable `X`, the survival function is defined:
|
||||||
|
|
||||||
|
```
|
||||||
|
survival_function(x) = P[X > x]
|
||||||
|
= 1 - P[X <= x]
|
||||||
|
= 1 - cdf(x).
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: `float` or `double` `Tensor`.
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
|
||||||
|
`self.dtype`.
|
||||||
|
"""
|
||||||
|
self._check_hasattr(self._survival_function)
|
||||||
|
with self._name_scope(name, values=[value]):
|
||||||
|
value = ops.convert_to_tensor(value, name="value")
|
||||||
|
return self._survival_function(value)
|
||||||
|
|
||||||
def entropy(self, name="entropy"):
|
def entropy(self, name="entropy"):
|
||||||
"""Shannon entropy in nats."""
|
"""Shannon entropy in nats."""
|
||||||
self._check_hasattr(self._entropy)
|
self._check_hasattr(self._entropy)
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
from tensorflow.contrib.bayesflow.python.ops import special_math
|
||||||
from tensorflow.contrib.distributions.python.ops import distribution
|
from tensorflow.contrib.distributions.python.ops import distribution
|
||||||
from tensorflow.contrib.distributions.python.ops import kullback_leibler
|
from tensorflow.contrib.distributions.python.ops import kullback_leibler
|
||||||
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
|
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
|
||||||
@ -169,20 +170,22 @@ class Normal(distribution.Distribution):
|
|||||||
|
|
||||||
def _log_prob(self, x):
|
def _log_prob(self, x):
|
||||||
return (-0.5 * math.log(2. * math.pi) - math_ops.log(self.sigma)
|
return (-0.5 * math.log(2. * math.pi) - math_ops.log(self.sigma)
|
||||||
-0.5 * math_ops.square((x - self.mu) / self.sigma))
|
-0.5 * math_ops.square(self._z(x)))
|
||||||
|
|
||||||
def _prob(self, x):
|
def _prob(self, x):
|
||||||
return math_ops.exp(self._log_prob(x))
|
return math_ops.exp(self._log_prob(x))
|
||||||
|
|
||||||
def _log_cdf(self, x):
|
def _log_cdf(self, x):
|
||||||
return math_ops.log(self._cdf(x))
|
return special_math.log_ndtr(self._z(x))
|
||||||
|
|
||||||
def _cdf(self, x):
|
def _cdf(self, x):
|
||||||
# TODO(ebrevdo): wrap this in a Defun with a custom Defun
|
return special_math.ndtr(self._z(x))
|
||||||
# gradient because the analytic gradient may be faster than
|
|
||||||
# automatic differentiation.
|
def _log_survival_function(self, x):
|
||||||
return (0.5 + 0.5*math_ops.erf(
|
return special_math.log_ndtr(-self._z(x))
|
||||||
1. / (math.sqrt(2.) * self.sigma) * (x - self.mu)))
|
|
||||||
|
def _survival_function(self, x):
|
||||||
|
return special_math.ndtr(-self._z(x))
|
||||||
|
|
||||||
def _entropy(self):
|
def _entropy(self):
|
||||||
# Use broadcasting rules to calculate the full broadcast sigma.
|
# Use broadcasting rules to calculate the full broadcast sigma.
|
||||||
@ -201,6 +204,11 @@ class Normal(distribution.Distribution):
|
|||||||
def _mode(self):
|
def _mode(self):
|
||||||
return self._mean()
|
return self._mean()
|
||||||
|
|
||||||
|
def _z(self, x):
|
||||||
|
"""Standardize input `x` to a unit normal."""
|
||||||
|
with ops.name_scope("standardize", values=[x]):
|
||||||
|
return (x - self.mu) / self.sigma
|
||||||
|
|
||||||
|
|
||||||
@kullback_leibler.RegisterKL(Normal, Normal)
|
@kullback_leibler.RegisterKL(Normal, Normal)
|
||||||
def _kl_normal_normal(n_a, n_b, name=None):
|
def _kl_normal_normal(n_a, n_b, name=None):
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.contrib.factorization.python.ops import gmm_ops
|
|||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
||||||
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||||
|
|
||||||
|
|
||||||
@ -166,12 +167,17 @@ class GMM(estimator.Estimator, TransformerMixin):
|
|||||||
self.model_dir,
|
self.model_dir,
|
||||||
gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
|
gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
|
||||||
|
|
||||||
|
def _parse_tensor_or_dict(self, features):
|
||||||
|
if isinstance(features, dict):
|
||||||
|
return array_ops.concat(1, [features[k] for k in sorted(features.keys())])
|
||||||
|
return features
|
||||||
|
|
||||||
def _get_train_ops(self, features, _):
|
def _get_train_ops(self, features, _):
|
||||||
(_,
|
(_,
|
||||||
_,
|
_,
|
||||||
losses,
|
losses,
|
||||||
training_op) = gmm_ops.gmm(
|
training_op) = gmm_ops.gmm(
|
||||||
features,
|
self._parse_tensor_or_dict(features),
|
||||||
self._training_initial_clusters,
|
self._training_initial_clusters,
|
||||||
self._num_clusters,
|
self._num_clusters,
|
||||||
self._random_seed,
|
self._random_seed,
|
||||||
@ -187,7 +193,7 @@ class GMM(estimator.Estimator, TransformerMixin):
|
|||||||
model_predictions,
|
model_predictions,
|
||||||
_,
|
_,
|
||||||
_) = gmm_ops.gmm(
|
_) = gmm_ops.gmm(
|
||||||
features,
|
self._parse_tensor_or_dict(features),
|
||||||
self._training_initial_clusters,
|
self._training_initial_clusters,
|
||||||
self._num_clusters,
|
self._num_clusters,
|
||||||
self._random_seed,
|
self._random_seed,
|
||||||
@ -203,7 +209,7 @@ class GMM(estimator.Estimator, TransformerMixin):
|
|||||||
_,
|
_,
|
||||||
losses,
|
losses,
|
||||||
_) = gmm_ops.gmm(
|
_) = gmm_ops.gmm(
|
||||||
features,
|
self._parse_tensor_or_dict(features),
|
||||||
self._training_initial_clusters,
|
self._training_initial_clusters,
|
||||||
self._num_clusters,
|
self._num_clusters,
|
||||||
self._random_seed,
|
self._random_seed,
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib.learn.python.learn.estimators import estimator
|
|||||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
||||||
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
||||||
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||||
|
|
||||||
SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
|
SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
|
||||||
@ -222,12 +223,17 @@ class KMeansClustering(estimator.Estimator,
|
|||||||
"""Returns cluster centers."""
|
"""Returns cluster centers."""
|
||||||
return tf.contrib.framework.load_variable(self.model_dir, self.CLUSTERS)
|
return tf.contrib.framework.load_variable(self.model_dir, self.CLUSTERS)
|
||||||
|
|
||||||
|
def _parse_tensor_or_dict(self, features):
|
||||||
|
if isinstance(features, dict):
|
||||||
|
return array_ops.concat(1, [features[k] for k in sorted(features.keys())])
|
||||||
|
return features
|
||||||
|
|
||||||
def _get_train_ops(self, features, _):
|
def _get_train_ops(self, features, _):
|
||||||
(_,
|
(_,
|
||||||
_,
|
_,
|
||||||
losses,
|
losses,
|
||||||
training_op) = clustering_ops.KMeans(
|
training_op) = clustering_ops.KMeans(
|
||||||
features,
|
self._parse_tensor_or_dict(features),
|
||||||
self._num_clusters,
|
self._num_clusters,
|
||||||
self._training_initial_clusters,
|
self._training_initial_clusters,
|
||||||
self._distance_metric,
|
self._distance_metric,
|
||||||
@ -245,7 +251,7 @@ class KMeansClustering(estimator.Estimator,
|
|||||||
model_predictions,
|
model_predictions,
|
||||||
_,
|
_,
|
||||||
_) = clustering_ops.KMeans(
|
_) = clustering_ops.KMeans(
|
||||||
features,
|
self._parse_tensor_or_dict(features),
|
||||||
self._num_clusters,
|
self._num_clusters,
|
||||||
self._training_initial_clusters,
|
self._training_initial_clusters,
|
||||||
self._distance_metric,
|
self._distance_metric,
|
||||||
@ -263,7 +269,7 @@ class KMeansClustering(estimator.Estimator,
|
|||||||
_,
|
_,
|
||||||
losses,
|
losses,
|
||||||
_) = clustering_ops.KMeans(
|
_) = clustering_ops.KMeans(
|
||||||
features,
|
self._parse_tensor_or_dict(features),
|
||||||
self._num_clusters,
|
self._num_clusters,
|
||||||
self._training_initial_clusters,
|
self._training_initial_clusters,
|
||||||
self._distance_metric,
|
self._distance_metric,
|
||||||
|
@ -21,6 +21,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/grid_rnn_test.py"],
|
srcs = ["python/kernel_tests/grid_rnn_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":grid_rnn_py",
|
":grid_rnn_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
@ -127,6 +127,7 @@ py_test(
|
|||||||
name = "optimizers_test",
|
name = "optimizers_test",
|
||||||
srcs = ["python/layers/optimizers_test.py"],
|
srcs = ["python/layers/optimizers_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
|
tags = ["manual"], # http://b/31223979
|
||||||
deps = [
|
deps = [
|
||||||
":layers_py",
|
":layers_py",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
|
@ -181,7 +181,7 @@ class _TargetColumn(object):
|
|||||||
weight_tensor, shape=(-1,)))
|
weight_tensor, shape=(-1,)))
|
||||||
return weighted_loss
|
return weighted_loss
|
||||||
|
|
||||||
def training_loss(self, logits, target, features):
|
def training_loss(self, logits, target, features, name="training_loss"):
|
||||||
"""Returns training loss tensor for this head.
|
"""Returns training loss tensor for this head.
|
||||||
|
|
||||||
Training loss is different from the loss reported on the tensorboard as we
|
Training loss is different from the loss reported on the tensorboard as we
|
||||||
@ -197,6 +197,7 @@ class _TargetColumn(object):
|
|||||||
target: either a tensor for labels or in multihead case, a dict of string
|
target: either a tensor for labels or in multihead case, a dict of string
|
||||||
to target tensor.
|
to target tensor.
|
||||||
features: features dict.
|
features: features dict.
|
||||||
|
name: Op name.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loss tensor.
|
Loss tensor.
|
||||||
@ -206,10 +207,9 @@ class _TargetColumn(object):
|
|||||||
|
|
||||||
weight_tensor = self.get_weight_tensor(features)
|
weight_tensor = self.get_weight_tensor(features)
|
||||||
if weight_tensor is None:
|
if weight_tensor is None:
|
||||||
return math_ops.reduce_mean(loss_unweighted, name="loss")
|
return math_ops.reduce_mean(loss_unweighted, name=name)
|
||||||
else:
|
|
||||||
loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
|
loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
|
||||||
return math_ops.reduce_mean(loss_weighted, name="loss")
|
return math_ops.reduce_mean(loss_weighted, name=name)
|
||||||
|
|
||||||
def loss(self, logits, target, features):
|
def loss(self, logits, target, features):
|
||||||
"""Returns loss tensor for this head.
|
"""Returns loss tensor for this head.
|
||||||
@ -233,7 +233,6 @@ class _TargetColumn(object):
|
|||||||
weight_tensor = self.get_weight_tensor(features)
|
weight_tensor = self.get_weight_tensor(features)
|
||||||
if weight_tensor is None:
|
if weight_tensor is None:
|
||||||
return math_ops.reduce_mean(loss_unweighted, name="loss")
|
return math_ops.reduce_mean(loss_unweighted, name="loss")
|
||||||
else:
|
|
||||||
loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
|
loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
|
||||||
return math_ops.div(
|
return math_ops.div(
|
||||||
math_ops.reduce_sum(loss_weighted),
|
math_ops.reduce_sum(loss_weighted),
|
||||||
@ -409,8 +408,10 @@ def _run_metrics(predictions, targets, metrics, weights):
|
|||||||
result = {}
|
result = {}
|
||||||
targets = math_ops.cast(targets, predictions.dtype)
|
targets = math_ops.cast(targets, predictions.dtype)
|
||||||
for name, metric in six.iteritems(metrics or {}):
|
for name, metric in six.iteritems(metrics or {}):
|
||||||
result[name] = metrics_lib.run_metric(
|
if weights is not None:
|
||||||
metric, predictions, targets, weights=weights)
|
result[name] = metric(predictions, targets, weights=weights)
|
||||||
|
else:
|
||||||
|
result[name] = metric(predictions, targets)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -299,6 +299,18 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "metric_spec_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/learn/tests/metric_spec_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":learn",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "experiment_test",
|
name = "experiment_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -35,12 +35,12 @@ from tensorflow.contrib.learn.python.learn.dataframe import *
|
|||||||
from tensorflow.contrib.learn.python.learn.estimators import *
|
from tensorflow.contrib.learn.python.learn.estimators import *
|
||||||
from tensorflow.contrib.learn.python.learn.evaluable import Evaluable
|
from tensorflow.contrib.learn.python.learn.evaluable import Evaluable
|
||||||
from tensorflow.contrib.learn.python.learn.experiment import Experiment
|
from tensorflow.contrib.learn.python.learn.experiment import Experiment
|
||||||
from tensorflow.contrib.learn.python.learn.monitors import NanLossDuringTrainingError
|
|
||||||
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
|
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
|
||||||
from tensorflow.contrib.learn.python.learn.graph_actions import infer
|
from tensorflow.contrib.learn.python.learn.graph_actions import infer
|
||||||
from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
|
from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
|
||||||
from tensorflow.contrib.learn.python.learn.graph_actions import run_n
|
from tensorflow.contrib.learn.python.learn.graph_actions import run_n
|
||||||
from tensorflow.contrib.learn.python.learn.graph_actions import train
|
from tensorflow.contrib.learn.python.learn.graph_actions import train
|
||||||
from tensorflow.contrib.learn.python.learn.learn_io import *
|
from tensorflow.contrib.learn.python.learn.learn_io import *
|
||||||
|
from tensorflow.contrib.learn.python.learn.monitors import NanLossDuringTrainingError
|
||||||
from tensorflow.contrib.learn.python.learn.trainable import Trainable
|
from tensorflow.contrib.learn.python.learn.trainable import Trainable
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
|
@ -79,8 +79,6 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
|
|||||||
Both features' `value` must be a `SparseTensor`.
|
Both features' `value` must be a `SparseTensor`.
|
||||||
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
||||||
whose `value` is a `Tensor`.
|
whose `value` is a `Tensor`.
|
||||||
- if `feature_columns` is `None`, then `input` must contain only real
|
|
||||||
valued `Tensor`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -211,8 +209,6 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
|
|||||||
Both features' `value` must be a `SparseTensor`.
|
Both features' `value` must be a `SparseTensor`.
|
||||||
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
||||||
whose `value` is a `Tensor`.
|
whose `value` is a `Tensor`.
|
||||||
- if `feature_columns` is `None`, then `input` must contain only real
|
|
||||||
valued `Tensor`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -253,9 +253,11 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
|||||||
logits = array_ops.reshape(
|
logits = array_ops.reshape(
|
||||||
array_ops.tile(centered_bias[0], [batch_size]),
|
array_ops.tile(centered_bias[0], [batch_size]),
|
||||||
[batch_size, self._target_column.num_label_columns])
|
[batch_size, self._target_column.num_label_columns])
|
||||||
training_loss = self._target_column.training_loss(logits, targets, features)
|
with ops.name_scope(None, "centered_bias", (targets, features)):
|
||||||
# Learn central bias by an optimizer. 0.1 is a convervative lr for a single
|
training_loss = self._target_column.training_loss(
|
||||||
# variable.
|
logits, targets, features)
|
||||||
|
# Learn central bias by an optimizer. 0.1 is a conservative lr for a
|
||||||
|
# single variable.
|
||||||
return training.AdagradOptimizer(0.1).minimize(
|
return training.AdagradOptimizer(0.1).minimize(
|
||||||
training_loss, var_list=centered_bias)
|
training_loss, var_list=centered_bias)
|
||||||
|
|
||||||
|
@ -223,10 +223,13 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
|||||||
linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
||||||
dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
||||||
dnn_hidden_units=[3, 3])
|
dnn_hidden_units=[3, 3])
|
||||||
|
classifier.fit(input_fn=_input_fn_train, steps=100, monitors=(
|
||||||
classifier.fit(input_fn=_input_fn_train, steps=100)
|
tf.contrib.learn.monitors.CaptureVariable(var_name='loss'),
|
||||||
scores = classifier.evaluate(input_fn=_input_fn_eval,
|
tf.contrib.learn.monitors.CaptureVariable(
|
||||||
steps=100)
|
var_name='centered_bias/training_loss'),
|
||||||
|
tf.contrib.learn.monitors.CaptureVariable(var_name='training_loss'),
|
||||||
|
))
|
||||||
|
scores = classifier.evaluate(input_fn=_input_fn_eval, steps=100)
|
||||||
# If there is no weight column, model should learn y=Not(x). All examples in
|
# If there is no weight column, model should learn y=Not(x). All examples in
|
||||||
# eval data set are y=x. So if weight column is ignored, then accuracy
|
# eval data set are y=x. So if weight column is ignored, then accuracy
|
||||||
# should be zero.
|
# should be zero.
|
||||||
@ -251,8 +254,12 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
|||||||
linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
||||||
dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
||||||
dnn_hidden_units=[3, 3])
|
dnn_hidden_units=[3, 3])
|
||||||
|
classifier.fit(input_fn=_input_fn_train, steps=100, monitors=(
|
||||||
classifier.fit(input_fn=_input_fn_train, steps=100)
|
tf.contrib.learn.monitors.CaptureVariable(var_name='loss'),
|
||||||
|
tf.contrib.learn.monitors.CaptureVariable(
|
||||||
|
var_name='centered_bias/training_loss'),
|
||||||
|
tf.contrib.learn.monitors.CaptureVariable(var_name='training_loss'),
|
||||||
|
))
|
||||||
scores = classifier.evaluate(input_fn=_input_fn_train, steps=100)
|
scores = classifier.evaluate(input_fn=_input_fn_train, steps=100)
|
||||||
# If weight column is ignored, then accuracy should be 0.25. If it's not
|
# If weight column is ignored, then accuracy should be 0.25. If it's not
|
||||||
# ignored, then it should be greater than 0.6.
|
# ignored, then it should be greater than 0.6.
|
||||||
|
@ -37,6 +37,7 @@ from tensorflow.contrib.framework import deprecated
|
|||||||
from tensorflow.contrib.framework import deprecated_arg_values
|
from tensorflow.contrib.framework import deprecated_arg_values
|
||||||
from tensorflow.contrib.learn.python.learn import evaluable
|
from tensorflow.contrib.learn.python.learn import evaluable
|
||||||
from tensorflow.contrib.learn.python.learn import graph_actions
|
from tensorflow.contrib.learn.python.learn import graph_actions
|
||||||
|
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||||
from tensorflow.contrib.learn.python.learn import session_run_hook
|
from tensorflow.contrib.learn.python.learn import session_run_hook
|
||||||
from tensorflow.contrib.learn.python.learn import trainable
|
from tensorflow.contrib.learn.python.learn import trainable
|
||||||
@ -52,7 +53,6 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import device_setter
|
from tensorflow.python.training import device_setter
|
||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training import saver
|
||||||
@ -174,6 +174,76 @@ def _get_replica_device_setter(config):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _make_metrics_ops(metrics, features, targets, predictions):
|
||||||
|
"""Add metrics to run on features, targets, and predictions dicts or tensors.
|
||||||
|
|
||||||
|
`metrics` contains a specification for how to run metrics. It is a dict
|
||||||
|
mapping friendly names to either `MetricSpec` objects, or directly to a metric
|
||||||
|
function (assuming that predictions and targets are single tensors), or to
|
||||||
|
a `(pred_name, metric)` tuple, which passes `predictions[pred_name]` and
|
||||||
|
targets to `metric` (assuming targets is a single tensor).
|
||||||
|
|
||||||
|
Users are encouraged to use `MetricSpec` objects, which are more flexible and
|
||||||
|
cleaner. They also lead to clearer errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics: A dict mapping names to metrics specification, for example
|
||||||
|
`MetricSpec` objects.
|
||||||
|
features: A dict of tensors returned from an input_fn as features/inputs.
|
||||||
|
targets: A single tensor or a dict of tensors returned from an input_fn as
|
||||||
|
labels.
|
||||||
|
predictions: A single tensor or a dict of tensors output from a model as
|
||||||
|
predictions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict mapping the friendly names given in `metrics` to the result of calling the
|
||||||
|
given metric function.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If metrics specifications do not work with the type of
|
||||||
|
features/targets/predictions provided. Mostly, a dict is given but no
|
||||||
|
pred_name specified.
|
||||||
|
"""
|
||||||
|
metrics = metrics or {}
|
||||||
|
if isinstance(targets, dict) and len(targets) == 1:
|
||||||
|
# Unpack single target into just tensor.
|
||||||
|
targets = targets[list(targets.keys())[0]]
|
||||||
|
result = {}
|
||||||
|
for name, metric in six.iteritems(metrics):
|
||||||
|
if isinstance(metric, metric_spec.MetricSpec):
|
||||||
|
result[name] = metric.create_metric_ops(features, targets, predictions)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# TODO(b/31229024): Remove the rest of this loop
|
||||||
|
logging.warning('Please specify metrics using MetricSpec. Using bare '
|
||||||
|
'functions or (key, fn) tuples is deprecated and support '
|
||||||
|
'for it will be removed on Oct 1, 2016.')
|
||||||
|
|
||||||
|
if isinstance(name, tuple):
|
||||||
|
# Multi-head metrics.
|
||||||
|
if not isinstance(predictions, dict):
|
||||||
|
raise ValueError(
|
||||||
|
'Metrics passed provide (name, prediction), '
|
||||||
|
'but predictions are not dict. '
|
||||||
|
'Metrics: %s, Predictions: %s.' % (metrics, predictions))
|
||||||
|
# Here are two options: targets are single Tensor or a dict.
|
||||||
|
if isinstance(targets, dict) and name[1] in targets:
|
||||||
|
# If targets are dict and the prediction name is in it, apply metric.
|
||||||
|
result[name[0]] = metric(predictions[name[1]], targets[name[1]])
|
||||||
|
else:
|
||||||
|
# Otherwise pass the targets to the metric.
|
||||||
|
result[name[0]] = metric(predictions[name[1]], targets)
|
||||||
|
else:
|
||||||
|
# Single head metrics.
|
||||||
|
if isinstance(predictions, dict):
|
||||||
|
raise ValueError(
|
||||||
|
'Metrics passed provide only name, no prediction, '
|
||||||
|
'but predictions are dict. '
|
||||||
|
'Metrics: %s, Targets: %s.' % (metrics, targets))
|
||||||
|
result[name] = metric(predictions, targets)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class BaseEstimator(
|
class BaseEstimator(
|
||||||
sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
|
sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
|
||||||
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
|
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
|
||||||
@ -389,7 +459,7 @@ class BaseEstimator(
|
|||||||
'The signature of the input_fn accepted by export is changing to be '
|
'The signature of the input_fn accepted by export is changing to be '
|
||||||
'consistent with what\'s used by tf.Learn Estimator\'s train/evaluate. '
|
'consistent with what\'s used by tf.Learn Estimator\'s train/evaluate. '
|
||||||
'input_fn and input_feature_key will become required args, '
|
'input_fn and input_feature_key will become required args, '
|
||||||
'and use_deprecated_input_fn will default to False & be removed '
|
'and use_deprecated_input_fn will default to False and be removed '
|
||||||
'altogether.',
|
'altogether.',
|
||||||
use_deprecated_input_fn=True,
|
use_deprecated_input_fn=True,
|
||||||
input_fn=None,
|
input_fn=None,
|
||||||
@ -470,15 +540,14 @@ class BaseEstimator(
|
|||||||
Args:
|
Args:
|
||||||
features: `Tensor` or `dict` of `Tensor` objects.
|
features: `Tensor` or `dict` of `Tensor` objects.
|
||||||
targets: `Tensor` or `dict` of `Tensor` objects.
|
targets: `Tensor` or `dict` of `Tensor` objects.
|
||||||
metrics: Dict of metric ops to run. If None, the default metric functions
|
metrics: Dict of metrics to run. If None, the default metric functions
|
||||||
are used; if {}, no metrics are used. If model has one output (i.e.,
|
are used; if {}, no metrics are used. Otherwise, `metrics` should map
|
||||||
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
|
friendly names for the metric to a `MetricSpec` object defining which
|
||||||
name of the metric that will show up in the logs / summaries.
|
model outputs to evaluate against which targets with which metric
|
||||||
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
|
function. Metric ops should support streaming, e.g., returning
|
||||||
- name of the metric and name of `Tensor` in the predictions to run
|
|
||||||
this metric on. Metric ops should support streaming, e.g., returning
|
|
||||||
update_op and value tensors. See more details in
|
update_op and value tensors. See more details in
|
||||||
../../../../metrics/python/metrics/ops/streaming_metrics.py.
|
`../../../../metrics/python/metrics/ops/streaming_metrics.py` and
|
||||||
|
`../metric_spec.py`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
metrics: `dict` of `Tensor` objects.
|
metrics: `dict` of `Tensor` objects.
|
||||||
@ -782,8 +851,7 @@ class Estimator(BaseEstimator):
|
|||||||
model_fn=None,
|
model_fn=None,
|
||||||
model_dir=None,
|
model_dir=None,
|
||||||
config=None,
|
config=None,
|
||||||
params=None,
|
params=None):
|
||||||
weight_column_name=None):
|
|
||||||
"""Constructs an Estimator instance.
|
"""Constructs an Estimator instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -816,9 +884,6 @@ class Estimator(BaseEstimator):
|
|||||||
config: Configuration object.
|
config: Configuration object.
|
||||||
params: `dict` of hyper parameters that will be passed into `model_fn`.
|
params: `dict` of hyper parameters that will be passed into `model_fn`.
|
||||||
Keys are names of parameters, values are basic python types.
|
Keys are names of parameters, values are basic python types.
|
||||||
weight_column_name: A string defining feature column name representing
|
|
||||||
weights. It is used to down weight or boost examples during training. It
|
|
||||||
will be multiplied by the loss of the example.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: parameters of `model_fn` don't match `params`.
|
ValueError: parameters of `model_fn` don't match `params`.
|
||||||
@ -831,17 +896,10 @@ class Estimator(BaseEstimator):
|
|||||||
raise ValueError('Estimator\'s model_fn (%s) has less than 4 '
|
raise ValueError('Estimator\'s model_fn (%s) has less than 4 '
|
||||||
'arguments, but not None params (%s) are passed.' %
|
'arguments, but not None params (%s) are passed.' %
|
||||||
(model_fn, params))
|
(model_fn, params))
|
||||||
if (params is None and weight_column_name is None and
|
if params is None and 'params' in model_fn_args:
|
||||||
'params' in model_fn_args):
|
|
||||||
logging.warning('Estimator\'s model_fn (%s) has includes params '
|
logging.warning('Estimator\'s model_fn (%s) has includes params '
|
||||||
'argument, but params are not passed to Estimator.',
|
'argument, but params are not passed to Estimator.',
|
||||||
model_fn)
|
model_fn)
|
||||||
self.weight_column_name = weight_column_name
|
|
||||||
if weight_column_name is not None:
|
|
||||||
if params is None:
|
|
||||||
params = {'weight_column_name': weight_column_name}
|
|
||||||
else:
|
|
||||||
params['weight_column_name'] = weight_column_name
|
|
||||||
self._model_fn = model_fn
|
self._model_fn = model_fn
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
@ -855,11 +913,6 @@ class Estimator(BaseEstimator):
|
|||||||
return self._model_fn(features, targets, mode=mode)
|
return self._model_fn(features, targets, mode=mode)
|
||||||
return self._model_fn(features, targets)
|
return self._model_fn(features, targets)
|
||||||
|
|
||||||
def _get_weight_tensor(self, features):
|
|
||||||
if not self.weight_column_name:
|
|
||||||
return None
|
|
||||||
return math_ops.to_float(features[self.weight_column_name])
|
|
||||||
|
|
||||||
def _get_train_ops(self, features, targets):
|
def _get_train_ops(self, features, targets):
|
||||||
"""Method that builds model graph and returns trainer ops.
|
"""Method that builds model graph and returns trainer ops.
|
||||||
|
|
||||||
@ -887,15 +940,14 @@ class Estimator(BaseEstimator):
|
|||||||
Args:
|
Args:
|
||||||
features: `Tensor` or `dict` of `Tensor` objects.
|
features: `Tensor` or `dict` of `Tensor` objects.
|
||||||
targets: `Tensor` or `dict` of `Tensor` objects.
|
targets: `Tensor` or `dict` of `Tensor` objects.
|
||||||
metrics: Dict of metric ops to run. If None, the default metric functions
|
metrics: Dict of metrics to run. If None, the default metric functions
|
||||||
are used; if {}, no metrics are used. If model has one output (i.e.,
|
are used; if {}, no metrics are used. Otherwise, `metrics` should map
|
||||||
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
|
friendly names for the metric to a `MetricSpec` object defining which
|
||||||
name of the metric that will show up in the logs / summaries.
|
model outputs to evaluate against which targets with which metric
|
||||||
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
|
function. Metric ops should support streaming, e.g., returning
|
||||||
- name of the metric and name of `Tensor` in the predictions to run
|
|
||||||
this metric on. Metric ops should support streaming, e.g., returning
|
|
||||||
update_op and value tensors. See more details in
|
update_op and value tensors. See more details in
|
||||||
../../../../metrics/python/metrics/ops/streaming_metrics.py.
|
`../../../../metrics/python/metrics/ops/streaming_metrics.py` and
|
||||||
|
`../metric_spec.py`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
metrics: `dict` of `Tensor` objects.
|
metrics: `dict` of `Tensor` objects.
|
||||||
@ -905,38 +957,7 @@ class Estimator(BaseEstimator):
|
|||||||
"""
|
"""
|
||||||
predictions, loss, _ = self._call_model_fn(features, targets, ModeKeys.EVAL)
|
predictions, loss, _ = self._call_model_fn(features, targets, ModeKeys.EVAL)
|
||||||
result = {'loss': metrics_lib.streaming_mean(loss)}
|
result = {'loss': metrics_lib.streaming_mean(loss)}
|
||||||
|
result.update(_make_metrics_ops(metrics, features, targets, predictions))
|
||||||
weights = self._get_weight_tensor(features)
|
|
||||||
metrics = metrics or {}
|
|
||||||
if isinstance(targets, dict) and len(targets) == 1:
|
|
||||||
# Unpack single target into just tensor.
|
|
||||||
targets = targets[list(targets.keys())[0]]
|
|
||||||
for name, metric in six.iteritems(metrics):
|
|
||||||
if isinstance(name, tuple):
|
|
||||||
# Multi-head metrics.
|
|
||||||
if not isinstance(predictions, dict):
|
|
||||||
raise ValueError(
|
|
||||||
'Metrics passed provide (name, prediction), '
|
|
||||||
'but predictions are not dict. '
|
|
||||||
'Metrics: %s, Predictions: %s.' % (metrics, predictions))
|
|
||||||
# Here are two options: targets are single Tensor or a dict.
|
|
||||||
if isinstance(targets, dict) and name[1] in targets:
|
|
||||||
# If targets are dict and the prediction name is in it, apply metric.
|
|
||||||
result[name[0]] = metrics_lib.run_metric(
|
|
||||||
metric, predictions[name[1]], targets[name[1]], weights)
|
|
||||||
else:
|
|
||||||
# Otherwise pass the targets to the metric.
|
|
||||||
result[name[0]] = metrics_lib.run_metric(
|
|
||||||
metric, predictions[name[1]], targets, weights)
|
|
||||||
else:
|
|
||||||
# Single head metrics.
|
|
||||||
if isinstance(predictions, dict):
|
|
||||||
raise ValueError(
|
|
||||||
'Metrics passed provide only name, no prediction, '
|
|
||||||
'but predictions are dict. '
|
|
||||||
'Metrics: %s, Targets: %s.' % (metrics, targets))
|
|
||||||
result[name] = metrics_lib.run_metric(
|
|
||||||
metric, predictions, targets, weights)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_predict_ops(self, features):
|
def _get_predict_ops(self, features):
|
||||||
|
@ -44,16 +44,6 @@ def boston_input_fn(num_epochs=None):
|
|||||||
return features, target
|
return features, target
|
||||||
|
|
||||||
|
|
||||||
def boston_input_with_weight_fn():
|
|
||||||
boston = tf.contrib.learn.datasets.load_boston()
|
|
||||||
features = {}
|
|
||||||
features['data'] = tf.reshape(
|
|
||||||
tf.constant(boston.data), [-1, _BOSTON_INPUT_DIM])
|
|
||||||
target = tf.reshape(tf.constant(boston.target), [-1, 1])
|
|
||||||
features['weight'] = tf.mul(0.5, tf.ones(target.get_shape()))
|
|
||||||
return features, target
|
|
||||||
|
|
||||||
|
|
||||||
def iris_input_fn():
|
def iris_input_fn():
|
||||||
iris = tf.contrib.learn.datasets.load_iris()
|
iris = tf.contrib.learn.datasets.load_iris()
|
||||||
features = tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM])
|
features = tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM])
|
||||||
@ -92,42 +82,6 @@ def linear_model_fn(features, target, mode):
|
|||||||
return prediction, loss, train_op
|
return prediction, loss, train_op
|
||||||
|
|
||||||
|
|
||||||
def linear_model_with_weights_fn(features, target, mode):
|
|
||||||
assert mode in ('train', 'eval', 'infer')
|
|
||||||
prediction, loss = (
|
|
||||||
tf.contrib.learn.models.linear_regression_zero_init(
|
|
||||||
features['data'], target)
|
|
||||||
)
|
|
||||||
train_op = tf.contrib.layers.optimize_loss(
|
|
||||||
loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
|
|
||||||
learning_rate=0.1)
|
|
||||||
return prediction, loss, train_op
|
|
||||||
|
|
||||||
|
|
||||||
def linear_model_with_weights_and_params_fn(features, target, mode, params):
|
|
||||||
assert mode in ('train', 'eval', 'infer')
|
|
||||||
prediction, loss = (
|
|
||||||
tf.contrib.learn.models.linear_regression_zero_init(
|
|
||||||
features['data'], target)
|
|
||||||
)
|
|
||||||
train_op = tf.contrib.layers.optimize_loss(
|
|
||||||
loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
|
|
||||||
learning_rate=params['learning_rate'])
|
|
||||||
return prediction, loss, train_op
|
|
||||||
|
|
||||||
|
|
||||||
def squared_error_weighted_sum(predictions, targets, weights=None):
|
|
||||||
squared_error = tf.to_float(tf.square(predictions - targets))
|
|
||||||
if weights is None:
|
|
||||||
return tf.reduce_sum(squared_error)
|
|
||||||
else:
|
|
||||||
return tf.reduce_sum(tf.mul(squared_error, weights))
|
|
||||||
|
|
||||||
|
|
||||||
def squared_error_no_weight(predictions, targets):
|
|
||||||
return squared_error_weighted_sum(predictions, targets)
|
|
||||||
|
|
||||||
|
|
||||||
def logistic_model_no_mode_fn(features, target):
|
def logistic_model_no_mode_fn(features, target):
|
||||||
target = tf.one_hot(target, 3, 1, 0)
|
target = tf.one_hot(target, 3, 1, 0)
|
||||||
prediction, loss = (
|
prediction, loss = (
|
||||||
@ -384,40 +338,6 @@ class EstimatorTest(tf.test.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
est.fit(input_fn=other_input_fn, steps=1)
|
est.fit(input_fn=other_input_fn, steps=1)
|
||||||
|
|
||||||
def testEstimatorWithWeight(self):
|
|
||||||
est = tf.contrib.learn.Estimator(model_fn=linear_model_with_weights_fn,
|
|
||||||
weight_column_name='weight')
|
|
||||||
self.assertTrue(est.params is not None)
|
|
||||||
self.assertTrue('weight_column_name' in est.params)
|
|
||||||
est.fit(input_fn=boston_input_with_weight_fn, steps=100)
|
|
||||||
scores = est.evaluate(
|
|
||||||
input_fn=boston_input_with_weight_fn, steps=100,
|
|
||||||
metrics={'SEWS': squared_error_weighted_sum,
|
|
||||||
'SE': squared_error_no_weight})
|
|
||||||
self.assertNear(scores['SEWS']*2, scores['SE'], 0.01)
|
|
||||||
|
|
||||||
def testEstimatorWithWeightAndParams(self):
|
|
||||||
est = tf.contrib.learn.Estimator(
|
|
||||||
model_fn=linear_model_with_weights_and_params_fn,
|
|
||||||
params={'learning_rate': 0.01},
|
|
||||||
weight_column_name='weight')
|
|
||||||
self.assertTrue('weight_column_name' in est.params)
|
|
||||||
est.fit(input_fn=boston_input_with_weight_fn, steps=100)
|
|
||||||
scores = est.evaluate(
|
|
||||||
input_fn=boston_input_with_weight_fn, steps=100,
|
|
||||||
metrics={'SEWS': squared_error_weighted_sum,
|
|
||||||
'SE': squared_error_no_weight})
|
|
||||||
self.assertNear(scores['SEWS']*2, scores['SE'], 0.01)
|
|
||||||
|
|
||||||
def testEstimatorWithNoWeight(self):
|
|
||||||
est = tf.contrib.learn.Estimator(model_fn=linear_model_with_weights_fn)
|
|
||||||
est.fit(input_fn=boston_input_with_weight_fn, steps=100)
|
|
||||||
scores = est.evaluate(
|
|
||||||
input_fn=boston_input_with_weight_fn, steps=100,
|
|
||||||
metrics={'SEWS': squared_error_weighted_sum,
|
|
||||||
'SE': squared_error_no_weight})
|
|
||||||
self.assertNear(scores['SEWS'], scores['SE'], 0.01)
|
|
||||||
|
|
||||||
def testMonitors(self):
|
def testMonitors(self):
|
||||||
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
|
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
|
||||||
est.fit(input_fn=boston_input_fn,
|
est.fit(input_fn=boston_input_fn,
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.contrib import metrics as metrics_lib
|
|||||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||||
from tensorflow.contrib.layers.python.layers import target_column
|
from tensorflow.contrib.layers.python.layers import target_column
|
||||||
from tensorflow.contrib.learn.python.learn import evaluable
|
from tensorflow.contrib.learn.python.learn import evaluable
|
||||||
|
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||||
from tensorflow.contrib.learn.python.learn import trainable
|
from tensorflow.contrib.learn.python.learn import trainable
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
|
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
@ -70,7 +71,7 @@ def _wrap_metric(metric):
|
|||||||
targets = math_ops.cast(targets, preds.dtype)
|
targets = math_ops.cast(targets, preds.dtype)
|
||||||
return metric(preds, targets)
|
return metric(preds, targets)
|
||||||
|
|
||||||
def wrapped_weights(preds, targets, weights):
|
def wrapped_weights(preds, targets, weights=None):
|
||||||
targets = math_ops.cast(targets, preds.dtype)
|
targets = math_ops.cast(targets, preds.dtype)
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
weights = array_ops.reshape(math_ops.to_float(weights), shape=(-1,))
|
weights = array_ops.reshape(math_ops.to_float(weights), shape=(-1,))
|
||||||
@ -264,6 +265,7 @@ def sdca_classifier_model_fn(features, targets, mode, params):
|
|||||||
loss = None
|
loss = None
|
||||||
if mode != estimator.ModeKeys.INFER:
|
if mode != estimator.ModeKeys.INFER:
|
||||||
loss = math_ops.reduce_mean(loss_fn(logits, targets), name="loss")
|
loss = math_ops.reduce_mean(loss_fn(logits, targets), name="loss")
|
||||||
|
logging_ops.scalar_summary("loss", loss)
|
||||||
|
|
||||||
train_op = None
|
train_op = None
|
||||||
if mode == estimator.ModeKeys.TRAIN:
|
if mode == estimator.ModeKeys.TRAIN:
|
||||||
@ -347,8 +349,6 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
Both features' `value` must be a `SparseTensor`.
|
Both features' `value` must be a `SparseTensor`.
|
||||||
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
||||||
whose `value` is a `Tensor`.
|
whose `value` is a `Tensor`.
|
||||||
- if `feature_columns` is `None`, then `input` must contains only real
|
|
||||||
valued `Tensor`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -426,8 +426,7 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
model_fn=model_fn,
|
model_fn=model_fn,
|
||||||
model_dir=self._model_dir,
|
model_dir=self._model_dir,
|
||||||
config=config,
|
config=config,
|
||||||
params=params,
|
params=params)
|
||||||
weight_column_name=weight_column_name)
|
|
||||||
|
|
||||||
def get_estimator(self):
|
def get_estimator(self):
|
||||||
return self._estimator
|
return self._estimator
|
||||||
@ -445,14 +444,24 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
|||||||
"""See evaluable.Evaluable."""
|
"""See evaluable.Evaluable."""
|
||||||
if not metrics:
|
if not metrics:
|
||||||
metrics = {}
|
metrics = {}
|
||||||
metrics[("accuracy", _CLASSES)] = metrics_lib.streaming_accuracy
|
metrics["accuracy"] = metric_spec.MetricSpec(
|
||||||
|
metric_fn=metrics_lib.streaming_accuracy,
|
||||||
|
prediction_key=_CLASSES)
|
||||||
if self._n_classes == 2:
|
if self._n_classes == 2:
|
||||||
additional_metrics = (
|
additional_metrics = (
|
||||||
target_column.get_default_binary_metrics_for_eval([0.5]))
|
target_column.get_default_binary_metrics_for_eval([0.5]))
|
||||||
additional_metrics = {(name, _LOGISTIC): metric
|
additional_metrics = {
|
||||||
for name, metric in additional_metrics.items()}
|
name: metric_spec.MetricSpec(metric_fn=metric,
|
||||||
|
prediction_key=_LOGISTIC)
|
||||||
|
for name, metric in additional_metrics.items()
|
||||||
|
}
|
||||||
metrics.update(additional_metrics)
|
metrics.update(additional_metrics)
|
||||||
|
|
||||||
|
# TODO(b/31229024): Remove this loop
|
||||||
for metric_name, metric in metrics.items():
|
for metric_name, metric in metrics.items():
|
||||||
|
if isinstance(metric, metric_spec.MetricSpec):
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(metric_name, tuple):
|
if isinstance(metric_name, tuple):
|
||||||
if len(metric_name) != 2:
|
if len(metric_name) != 2:
|
||||||
raise ValueError("Ignoring metric %s. It returned a tuple with len "
|
raise ValueError("Ignoring metric %s. It returned a tuple with len "
|
||||||
@ -577,8 +586,6 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
|
|||||||
key=weight column name, value=a `SparseTensor`}
|
key=weight column name, value=a `SparseTensor`}
|
||||||
- if isinstance(column, `RealValuedColumn`):
|
- if isinstance(column, `RealValuedColumn`):
|
||||||
key=column.name, value=a `Tensor`
|
key=column.name, value=a `Tensor`
|
||||||
- if `feature_columns` is `None`:
|
|
||||||
input must contains only real valued `Tensor`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -26,6 +26,7 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
|
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
|
||||||
|
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
|
||||||
|
|
||||||
|
|
||||||
def _iris_input_fn():
|
def _iris_input_fn():
|
||||||
@ -137,8 +138,8 @@ class LinearClassifierTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def _input_fn_train():
|
def _input_fn_train():
|
||||||
# Create 4 rows, one of them (y = x), three of them (y=Not(x))
|
# Create 4 rows, one of them (y = x), three of them (y=Not(x))
|
||||||
target = tf.constant([[1], [0], [0], [0]])
|
target = tf.constant([[1], [0], [0], [0]], dtype=tf.float32)
|
||||||
features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
|
features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32)}
|
||||||
return features, target
|
return features, target
|
||||||
|
|
||||||
def _my_metric_op(predictions, targets):
|
def _my_metric_op(predictions, targets):
|
||||||
@ -155,9 +156,14 @@ class LinearClassifierTest(tf.test.TestCase):
|
|||||||
input_fn=_input_fn_train,
|
input_fn=_input_fn_train,
|
||||||
steps=100,
|
steps=100,
|
||||||
metrics={
|
metrics={
|
||||||
('my_accuracy', 'classes'): tf.contrib.metrics.streaming_accuracy,
|
'my_accuracy': MetricSpec(
|
||||||
('my_precision', 'classes'): tf.contrib.metrics.streaming_precision,
|
metric_fn=tf.contrib.metrics.streaming_accuracy,
|
||||||
('my_metric', 'probabilities'): _my_metric_op
|
prediction_key='classes'),
|
||||||
|
'my_precision': MetricSpec(
|
||||||
|
metric_fn=tf.contrib.metrics.streaming_precision,
|
||||||
|
prediction_key='classes'),
|
||||||
|
'my_metric': MetricSpec(metric_fn=_my_metric_op,
|
||||||
|
prediction_key='probabilities')
|
||||||
})
|
})
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
set(['loss', 'my_accuracy', 'my_precision', 'my_metric'
|
set(['loss', 'my_accuracy', 'my_precision', 'my_metric'
|
||||||
|
@ -25,13 +25,12 @@ from tensorflow.contrib import layers
|
|||||||
from tensorflow.contrib import metrics as metrics_lib
|
from tensorflow.contrib import metrics as metrics_lib
|
||||||
from tensorflow.contrib.layers.python.layers import target_column
|
from tensorflow.contrib.layers.python.layers import target_column
|
||||||
from tensorflow.contrib.learn.python.learn import evaluable
|
from tensorflow.contrib.learn.python.learn import evaluable
|
||||||
|
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||||
from tensorflow.contrib.learn.python.learn import trainable
|
from tensorflow.contrib.learn.python.learn import trainable
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import linear
|
from tensorflow.contrib.learn.python.learn.estimators import linear
|
||||||
from tensorflow.contrib.learn.python.learn.utils import checkpoints
|
from tensorflow.contrib.learn.python.learn.utils import checkpoints
|
||||||
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
|
|
||||||
|
|
||||||
def _as_iterable(preds, output):
|
def _as_iterable(preds, output):
|
||||||
@ -47,21 +46,6 @@ def _get_metric_args(metric):
|
|||||||
if arg not in metric.keywords.keys()]
|
if arg not in metric.keywords.keys()]
|
||||||
|
|
||||||
|
|
||||||
def _wrap_metric(metric):
|
|
||||||
"""Wraps metrics for mismatched prediction/target types."""
|
|
||||||
def wrapped(preds, targets):
|
|
||||||
targets = math_ops.cast(targets, preds.dtype)
|
|
||||||
return metric(preds, targets)
|
|
||||||
|
|
||||||
def wrapped_weights(preds, targets, weights):
|
|
||||||
targets = math_ops.cast(targets, preds.dtype)
|
|
||||||
if weights is not None:
|
|
||||||
weights = array_ops.reshape(math_ops.to_float(weights), shape=(-1,))
|
|
||||||
return metric(preds, targets, weights)
|
|
||||||
|
|
||||||
return wrapped_weights if "weights" in _get_metric_args(metric) else wrapped
|
|
||||||
|
|
||||||
|
|
||||||
class SVM(trainable.Trainable, evaluable.Evaluable):
|
class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||||
"""Support Vector Machine (SVM) model for binary classification.
|
"""Support Vector Machine (SVM) model for binary classification.
|
||||||
|
|
||||||
@ -100,9 +84,6 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
|||||||
whose `value` is a `SparseTensor`.
|
whose `value` is a `SparseTensor`.
|
||||||
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
||||||
whose `value` is a `Tensor`.
|
whose `value` is a `Tensor`.
|
||||||
- if `feature_columns` is None, then `input` must contains only real
|
|
||||||
valued `Tensor`.
|
|
||||||
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
example_id_column: A string defining the feature column name representing
|
example_id_column: A string defining the feature column name representing
|
||||||
@ -166,15 +147,24 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
|||||||
batch_size=None, steps=None, metrics=None, name=None):
|
batch_size=None, steps=None, metrics=None, name=None):
|
||||||
"""See evaluable.Evaluable."""
|
"""See evaluable.Evaluable."""
|
||||||
if not metrics:
|
if not metrics:
|
||||||
metrics = {
|
metrics = {}
|
||||||
("accuracy", linear._CLASSES): metrics_lib.streaming_accuracy,
|
metrics["accuracy"] = metric_spec.MetricSpec(
|
||||||
}
|
metric_fn=metrics_lib.streaming_accuracy,
|
||||||
|
prediction_key=linear._CLASSES)
|
||||||
additional_metrics = (
|
additional_metrics = (
|
||||||
target_column.get_default_binary_metrics_for_eval([0.5]))
|
target_column.get_default_binary_metrics_for_eval([0.5]))
|
||||||
additional_metrics = {(name, linear._LOGISTIC): metric
|
additional_metrics = {
|
||||||
for name, metric in additional_metrics.items()}
|
name: metric_spec.MetricSpec(metric_fn=metric,
|
||||||
|
prediction_key=linear._LOGISTIC)
|
||||||
|
for name, metric in additional_metrics.items()
|
||||||
|
}
|
||||||
metrics.update(additional_metrics)
|
metrics.update(additional_metrics)
|
||||||
|
|
||||||
|
# TODO(b/31229024): Remove this loop
|
||||||
for metric_name, metric in metrics.items():
|
for metric_name, metric in metrics.items():
|
||||||
|
if isinstance(metric, metric_spec.MetricSpec):
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(metric_name, tuple):
|
if isinstance(metric_name, tuple):
|
||||||
if len(metric_name) != 2:
|
if len(metric_name) != 2:
|
||||||
raise ValueError("Ignoring metric %s. It returned a tuple with len "
|
raise ValueError("Ignoring metric %s. It returned a tuple with len "
|
||||||
@ -184,7 +174,7 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
|||||||
if metric_name[1] not in valid_keys:
|
if metric_name[1] not in valid_keys:
|
||||||
raise ValueError("Ignoring metric %s. The 2nd element of its name "
|
raise ValueError("Ignoring metric %s. The 2nd element of its name "
|
||||||
"should be in %s" % (metric_name, valid_keys))
|
"should be in %s" % (metric_name, valid_keys))
|
||||||
metrics[metric_name] = _wrap_metric(metric)
|
metrics[metric_name] = linear._wrap_metric(metric)
|
||||||
return self._estimator.evaluate(x=x, y=y, input_fn=input_fn,
|
return self._estimator.evaluate(x=x, y=y, input_fn=input_fn,
|
||||||
feed_fn=feed_fn, batch_size=batch_size,
|
feed_fn=feed_fn, batch_size=batch_size,
|
||||||
steps=steps, metrics=metrics, name=name)
|
steps=steps, metrics=metrics, name=name)
|
||||||
|
186
tensorflow/contrib/learn/python/learn/metric_spec.py
Normal file
186
tensorflow/contrib/learn/python/learn/metric_spec.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""The metric spec class to flexibly connect models and metrics."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
|
||||||
|
class MetricSpec(object):
|
||||||
|
"""MetricSpec connects a model to metric functions.
|
||||||
|
|
||||||
|
The MetricSpec class contains all information necessary to connect the
|
||||||
|
output of a `model_fn` to the metrics (usually, streaming metrics) that are
|
||||||
|
used in evaluation.
|
||||||
|
|
||||||
|
It is passed in the `metrics` argument of `Estimator.evaluate`. The
|
||||||
|
`Estimator` then knows which predictions, labels, and weight to use to call a
|
||||||
|
given metric function.
|
||||||
|
|
||||||
|
When building the ops to run in evaluation, `Estimator` will call
|
||||||
|
`create_metric_ops`, which will connect the given `metric_fn` to the model
|
||||||
|
as detailed in the docstring for `create_metric_ops`, and return the metric.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
Assuming a model has an input function which returns inputs containing
|
||||||
|
(among other things) a tensor with key "income", and a labels dictionary
|
||||||
|
containing "has_clicked". Let's assume that the `model_fn` for this model
|
||||||
|
returns a prediction with key "clicked".
|
||||||
|
|
||||||
|
In order to compute the accuracy of the "clicked" prediction, we would add
|
||||||
|
```
|
||||||
|
"click accuracy": MetricSpec(metric_fn=streaming_accuracy,
|
||||||
|
prediction_key="clicked",
|
||||||
|
label_key="has_clicked")
|
||||||
|
```
|
||||||
|
to the metrics argument to `evaluate`. If we would like the accuracy to be
|
||||||
|
weighted by "income", we can add that as the `weight_key` argument.
|
||||||
|
```
|
||||||
|
"click accuracy": MetricSpec(metric_fn=streaming_accuracy,
|
||||||
|
prediction_key="clicked",
|
||||||
|
label_key="has_clicked",
|
||||||
|
weight_key="income")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
metric_fn,
|
||||||
|
prediction_key=None,
|
||||||
|
label_key=None,
|
||||||
|
weight_key=None):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Creates a MetricSpec.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric_fn: A function to use as a metric. Must accept `predictions`,
|
||||||
|
`labels` and optionally, `weights` tensors as inputs, and must return
|
||||||
|
either a single tensor which is interpreted as a value of this metric,
|
||||||
|
or a pair `(value_op, update_op)`, where value_op is the op to call to
|
||||||
|
obtain the value of the metric, and update_op should be evaluated for
|
||||||
|
each batch in order to update internal state.
|
||||||
|
prediction_key: The key for a tensor in the `predictions` dict (output
|
||||||
|
from the `model_fn`) to use as the `predictions` input to the
|
||||||
|
`metric_fn`. Optional. If `None`, the `model_fn` must return a single
|
||||||
|
tensor or a dict with only a single entry as `predictions`.
|
||||||
|
label_key: The key for a tensor in the `labels` dict (output from the
|
||||||
|
`input_fn`) to use as the `labels` input to the `metric_fn`.
|
||||||
|
Optional. If `None`, the `input_fn` must return a single tensor or a
|
||||||
|
dict with only a single entry as `labels`.
|
||||||
|
weight_key: The key for a tensor in the `inputs` dict (output from the
|
||||||
|
`input_fn`) to use as the `weights` input to the `metric_fn`.
|
||||||
|
Optional. If `None`, no weights will be passed to the `metric_fn`.
|
||||||
|
"""
|
||||||
|
self._metric_fn = metric_fn
|
||||||
|
self._prediction_key = prediction_key
|
||||||
|
self._label_key = label_key
|
||||||
|
self._weight_key = weight_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prediction_key(self):
|
||||||
|
return self._prediction_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def label_key(self):
|
||||||
|
return self._label_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weight_key(self):
|
||||||
|
return self._weight_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metric_fn(self):
|
||||||
|
return self._metric_fn
|
||||||
|
|
||||||
|
def __str__(self):
  """Returns a readable one-line description of this MetricSpec.

  The metric function is shown by its `__name__` when it has one. Callables
  without a `__name__` (for example `functools.partial` objects) are shown
  by the `__name__` of their wrapped `func` when available, and by their
  `str()` otherwise, so that formatting a MetricSpec never raises because of
  the kind of callable it holds.

  Returns:
    A string of the form
    `MetricSpec(metric_fn=..., prediction_key=..., label_key=...,
    weight_key=...)`.
  """
  metric_fn = self.metric_fn
  if hasattr(metric_fn, '__name__'):
    fn_name = metric_fn.__name__
  elif hasattr(getattr(metric_fn, 'func', None), '__name__'):
    # functools.partial exposes the wrapped callable as `.func`.
    fn_name = metric_fn.func.__name__
  else:
    fn_name = str(metric_fn)

  # A single format call with an explicit tuple keeps tuple-valued keys from
  # being unpacked as multiple format arguments.
  return ('MetricSpec(metric_fn=%s, prediction_key=%s, label_key=%s, '
          'weight_key=%s)' % (fn_name, self.prediction_key, self.label_key,
                              self.weight_key))
|
||||||
|
|
||||||
|
def create_metric_ops(self, inputs, labels, predictions):
  """Connect our `metric_fn` to the specified members of the given dicts.

  This function will call the `metric_fn` given in our constructor as follows:
  ```
    metric_fn(predictions[self.prediction_key],
              labels[self.label_key],
              weights=inputs[self.weight_key])
  ```
  And returns the result. The `weights` argument is only passed if
  `self.weight_key` is set.

  `predictions` and `labels` may be single tensors as well as dicts. If
  `predictions` is a single tensor, `self.prediction_key` must be `None`. If
  `predictions` is a single element dict, `self.prediction_key` is allowed to
  be `None`. Conversely, if `labels` is a single tensor, `self.label_key` must
  be `None`. If `labels` is a single element dict, `self.label_key` is allowed
  to be `None`.

  Note that keys are tested for truthiness, so a falsy key (such as `None`
  or an empty string) is treated as "not specified".

  Args:
    inputs: A dict of inputs produced by the `input_fn`. Only indexed when
      `self.weight_key` is set.
    labels: A dict of labels or a single label tensor produced by the
      `input_fn`.
    predictions: A dict of predictions or a single tensor produced by the
      `model_fn`.

  Returns:
    The result of calling `metric_fn`.

  Raises:
    ValueError: If `predictions` or `labels` is a single `Tensor` and
      `self.prediction_key` or `self.label_key` is not `None`; or if
      `self.label_key` is `None` but `labels` is a dict with more than one
      element, or if `self.prediction_key` is `None` but `predictions` is a
      dict with more than one element.
  """
  def _get_dict(name, dict_or_tensor, key):
    """Get a single tensor or an element of a dict or raise ValueError."""
    if key:
      if not isinstance(dict_or_tensor, dict):
        # The offending value is wrapped in a 1-tuple so that tuple-valued
        # arguments are formatted instead of being unpacked by `%`.
        raise ValueError('MetricSpec with ' + name + '_key specified'
                         ' requires ' +
                         name + 's dict, got %s' % (dict_or_tensor,))
      return dict_or_tensor[key]
    else:
      if isinstance(dict_or_tensor, dict):
        if len(dict_or_tensor) != 1:
          raise ValueError('MetricSpec without specified ' + name + '_key'
                           ' requires ' + name + 's tensor or single element'
                           ' dict, got %s' % (dict_or_tensor,))
        # dict.values() is a non-indexable view on Python 3; take the single
        # element through an iterator, which works on Python 2 and 3 alike.
        return next(iter(dict_or_tensor.values()))
      else:
        return dict_or_tensor

  # Get the predictions
  prediction = _get_dict('prediction', predictions, self.prediction_key)

  # Get the labels
  label = _get_dict('label', labels, self.label_key)

  try:
    if self.weight_key:
      return self.metric_fn(prediction, label,
                            weights=inputs[self.weight_key])
    else:
      return self.metric_fn(prediction, label)
  except:  # pylint: disable=bare-except
    # Log which MetricSpec failed, then let the original error propagate.
    logging.error('Could not create metric ops for %s.' % self)
    raise
|
150
tensorflow/contrib/learn/python/learn/tests/metric_spec_test.py
Normal file
150
tensorflow/contrib/learn/python/learn/tests/metric_spec_test.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for MetricSpec."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric(predictions, labels, weights=None):
|
||||||
|
return predictions, labels, weights
|
||||||
|
|
||||||
|
|
||||||
|
class MetricSpecTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_create_metric_ops(self):
|
||||||
|
features = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
|
||||||
|
labels = {"label1": "label1_tensor", "label2": "label2_tensor"}
|
||||||
|
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||||
|
|
||||||
|
passed = MetricSpec(metric_fn=test_metric,
|
||||||
|
prediction_key="pred1",
|
||||||
|
label_key="label1",
|
||||||
|
weight_key="feature2").create_metric_ops(features,
|
||||||
|
labels,
|
||||||
|
predictions)
|
||||||
|
|
||||||
|
self.assertEqual(passed[0], "pred1_tensor")
|
||||||
|
self.assertEqual(passed[1], "label1_tensor")
|
||||||
|
self.assertEqual(passed[2], "feature2_tensor")
|
||||||
|
|
||||||
|
def test_no_weight(self):
|
||||||
|
features = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
|
||||||
|
labels = {"label1": "label1_tensor", "label2": "label2_tensor"}
|
||||||
|
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||||
|
|
||||||
|
passed = MetricSpec(metric_fn=test_metric,
|
||||||
|
prediction_key="pred1",
|
||||||
|
label_key="label1").create_metric_ops(features, labels,
|
||||||
|
predictions)
|
||||||
|
|
||||||
|
self.assertEqual(passed[0], "pred1_tensor")
|
||||||
|
self.assertEqual(passed[1], "label1_tensor")
|
||||||
|
self.assertEqual(passed[2], None)
|
||||||
|
|
||||||
|
def test_fail_no_prediction(self):
|
||||||
|
features = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
|
||||||
|
labels = {"label1": "label1_tensor", "label2": "label2_tensor"}
|
||||||
|
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||||
|
|
||||||
|
self.assertRaisesRegexp(ValueError,
|
||||||
|
"MetricSpec without specified prediction_key "
|
||||||
|
"requires predictions tensor or single element "
|
||||||
|
"dict, got",
|
||||||
|
MetricSpec(metric_fn=test_metric,
|
||||||
|
label_key="label1",
|
||||||
|
weight_key="feature2").create_metric_ops,
|
||||||
|
features, labels, predictions)
|
||||||
|
|
||||||
|
def test_fail_no_label(self):
|
||||||
|
features = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
|
||||||
|
labels = {"label1": "label1_tensor", "label2": "label2_tensor"}
|
||||||
|
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||||
|
|
||||||
|
self.assertRaisesRegexp(ValueError,
|
||||||
|
"MetricSpec without specified label_key requires "
|
||||||
|
"labels tensor or single element dict, got",
|
||||||
|
MetricSpec(metric_fn=test_metric,
|
||||||
|
prediction_key="pred1",
|
||||||
|
weight_key="feature2").create_metric_ops,
|
||||||
|
features, labels, predictions)
|
||||||
|
|
||||||
|
def test_single_prediction(self):
|
||||||
|
features = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
|
||||||
|
labels = {"label1": "label1_tensor", "label2": "label2_tensor"}
|
||||||
|
predictions = "pred1_tensor"
|
||||||
|
|
||||||
|
passed = MetricSpec(metric_fn=test_metric,
|
||||||
|
label_key="label1",
|
||||||
|
weight_key="feature2").create_metric_ops(features,
|
||||||
|
labels,
|
||||||
|
predictions)
|
||||||
|
|
||||||
|
self.assertEqual(passed[0], "pred1_tensor")
|
||||||
|
self.assertEqual(passed[1], "label1_tensor")
|
||||||
|
self.assertEqual(passed[2], "feature2_tensor")
|
||||||
|
|
||||||
|
def test_single_label(self):
|
||||||
|
features = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
|
||||||
|
labels = "label1_tensor"
|
||||||
|
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||||
|
|
||||||
|
passed = MetricSpec(metric_fn=test_metric,
|
||||||
|
prediction_key="pred1",
|
||||||
|
weight_key="feature2").create_metric_ops(features,
|
||||||
|
labels,
|
||||||
|
predictions)
|
||||||
|
|
||||||
|
self.assertEqual(passed[0], "pred1_tensor")
|
||||||
|
self.assertEqual(passed[1], "label1_tensor")
|
||||||
|
self.assertEqual(passed[2], "feature2_tensor")
|
||||||
|
|
||||||
|
def test_fail_single_prediction(self):
|
||||||
|
features = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
|
||||||
|
labels = {"label1": "label1_tensor", "label2": "label2_tensor"}
|
||||||
|
predictions = "pred1_tensor"
|
||||||
|
|
||||||
|
self.assertRaisesRegexp(ValueError,
|
||||||
|
"MetricSpec with prediction_key specified requires "
|
||||||
|
"predictions dict, got",
|
||||||
|
MetricSpec(metric_fn=test_metric,
|
||||||
|
prediction_key="pred1",
|
||||||
|
label_key="label1",
|
||||||
|
weight_key="feature2").create_metric_ops,
|
||||||
|
features, labels, predictions)
|
||||||
|
|
||||||
|
def test_fail_single_label(self):
|
||||||
|
features = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
|
||||||
|
labels = "label1_tensor"
|
||||||
|
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||||
|
|
||||||
|
self.assertRaisesRegexp(ValueError,
|
||||||
|
"MetricSpec with label_key specified requires "
|
||||||
|
"labels dict, got",
|
||||||
|
MetricSpec(metric_fn=test_metric,
|
||||||
|
prediction_key="pred1",
|
||||||
|
label_key="label1",
|
||||||
|
weight_key="feature2").create_metric_ops,
|
||||||
|
features, labels, predictions)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
@ -367,6 +367,7 @@ class SdcaModel(object):
|
|||||||
|
|
||||||
logging_ops.scalar_summary('approximate_duality_gap',
|
logging_ops.scalar_summary('approximate_duality_gap',
|
||||||
self.approximate_duality_gap())
|
self.approximate_duality_gap())
|
||||||
|
logging_ops.scalar_summary('examples_seen', self._hashtable.size())
|
||||||
|
|
||||||
def _symmetric_l1_regularization(self):
|
def _symmetric_l1_regularization(self):
|
||||||
return self._options['symmetric_l1_regularization']
|
return self._options['symmetric_l1_regularization']
|
||||||
|
@ -116,6 +116,7 @@ weighted average over the individual prediction errors:
|
|||||||
@@mean_squared_error
|
@@mean_squared_error
|
||||||
@@sigmoid_cross_entropy
|
@@sigmoid_cross_entropy
|
||||||
@@softmax_cross_entropy
|
@@softmax_cross_entropy
|
||||||
|
@@sparse_softmax_cross_entropy
|
||||||
|
|
||||||
The following are deprecated in favor of `mean_pairwise_squared_error` and
|
The following are deprecated in favor of `mean_pairwise_squared_error` and
|
||||||
`mean_squared_error`.
|
`mean_squared_error`.
|
||||||
|
@ -41,6 +41,7 @@ __all__ = ["absolute_difference",
|
|||||||
"mean_squared_error",
|
"mean_squared_error",
|
||||||
"sigmoid_cross_entropy",
|
"sigmoid_cross_entropy",
|
||||||
"softmax_cross_entropy",
|
"softmax_cross_entropy",
|
||||||
|
"sparse_softmax_cross_entropy",
|
||||||
"sum_of_pairwise_squares",
|
"sum_of_pairwise_squares",
|
||||||
"sum_of_squares"]
|
"sum_of_squares"]
|
||||||
|
|
||||||
@ -354,8 +355,8 @@ def softmax_cross_entropy(logits, onehot_labels, weight=1.0,
|
|||||||
A scalar `Tensor` representing the loss value.
|
A scalar `Tensor` representing the loss value.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the shape of `predictions` doesn't match that of `targets` or
|
ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
|
||||||
if the shape of `weight` is invalid or if `weight` is None.
|
or if the shape of `weight` is invalid or if `weight` is None.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(scope, "softmax_cross_entropy_loss",
|
with ops.name_scope(scope, "softmax_cross_entropy_loss",
|
||||||
[logits, onehot_labels]):
|
[logits, onehot_labels]):
|
||||||
@ -375,6 +376,39 @@ def softmax_cross_entropy(logits, onehot_labels, weight=1.0,
|
|||||||
return _compute_weighted_loss(losses, weight)
|
return _compute_weighted_loss(losses, weight)
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None):
  """Cross-entropy loss using tf.nn.sparse_softmax_cross_entropy_with_logits.

  `weight` acts as a coefficient for the loss. A scalar simply scales the
  loss by that value; a tensor of size [`batch_size`] weights each sample of
  the batch individually.

  Args:
    logits: [batch_size, num_classes] logits outputs of the network.
    labels: [batch_size, 1] or [batch_size] target labels of dtype `int32` or
      `int64` in the range `[0, num_classes)`.
    weight: Coefficients for the loss. The tensor must be a scalar or a tensor
      of shape [batch_size] or [batch_size, 1].
    scope: the scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If the shapes of logits, labels, and weight are incompatible, or
      if `weight` is None.
  """
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      [logits, labels]):
    # Bring labels to rank 1, sized by their leading dimension, so that both
    # [batch_size] and [batch_size, 1] inputs are accepted.
    leading_dim = array_ops.shape(labels)[0]
    labels = array_ops.reshape(labels, shape=[leading_dim])
    # Drop size-1 dimensions so a [batch_size, 1] weight becomes [batch_size].
    weight = array_ops.squeeze(weight)

    per_example_losses = nn.sparse_softmax_cross_entropy_with_logits(
        logits, labels, name="xentropy")
    return _compute_weighted_loss(per_example_losses, weight)
|
||||||
|
|
||||||
|
|
||||||
def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
|
def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
|
||||||
"""Adds a Log Loss term to the training procedure.
|
"""Adds a Log Loss term to the training procedure.
|
||||||
|
|
||||||
|
@ -173,7 +173,7 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
|
|||||||
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
|
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
|
||||||
self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)
|
self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)
|
||||||
|
|
||||||
def testAllWrongAllMissing(self):
|
def testAllWrongAllWeightsMissing(self):
|
||||||
logits = tf.constant([[10.0, 0.0, 0.0],
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
[0.0, 10.0, 0.0],
|
[0.0, 10.0, 0.0],
|
||||||
[0.0, 0.0, 10.0]])
|
[0.0, 0.0, 10.0]])
|
||||||
@ -185,7 +185,7 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
|
|||||||
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
|
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
|
||||||
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||||
|
|
||||||
def testSomeMissing(self):
|
def testSomeWeightsMissing(self):
|
||||||
logits = tf.constant([[10.0, 0.0, 0.0],
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
[0.0, 10.0, 0.0],
|
[0.0, 10.0, 0.0],
|
||||||
[0.0, 0.0, 10.0]])
|
[0.0, 0.0, 10.0]])
|
||||||
@ -235,6 +235,216 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
|
|||||||
self.assertAlmostEqual(loss.eval(), expected_value, 3)
|
self.assertAlmostEqual(loss.eval(), expected_value, 3)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testNoneWeightRaisesValueError(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[0], [1], [2]])
|
||||||
|
with self.test_session():
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight=None)
|
||||||
|
|
||||||
|
def testAllCorrectInt32Labels(self):
|
||||||
|
with self.test_session():
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[0], [1], [2]], dtype=tf.int32)
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels)
|
||||||
|
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||||
|
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||||
|
|
||||||
|
def testAllCorrectInt64Labels(self):
|
||||||
|
with self.test_session():
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[0], [1], [2]], dtype=tf.int64)
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels)
|
||||||
|
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||||
|
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||||
|
|
||||||
|
def testAllCorrectNonColumnLabels(self):
|
||||||
|
with self.test_session():
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([0, 1, 2])
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels)
|
||||||
|
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||||
|
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||||
|
|
||||||
|
def testAllWrongInt32Labels(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[2], [0], [1]], dtype=tf.int32)
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels)
|
||||||
|
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||||
|
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||||
|
|
||||||
|
def testAllWrongInt64Labels(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[2], [0], [1]], dtype=tf.int64)
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels)
|
||||||
|
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||||
|
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||||
|
|
||||||
|
def testAllWrongNonColumnLabels(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([2, 0, 1])
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels)
|
||||||
|
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||||
|
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||||
|
|
||||||
|
def testNonZeroLossWithPythonScalarWeight(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[2], [0], [1]])
|
||||||
|
weight = 2.3
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight)
|
||||||
|
self.assertAlmostEqual(loss.eval(), weight * 10.0, 3)
|
||||||
|
|
||||||
|
def testNonZeroLossWithScalarTensorWeight(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[2], [0], [1]])
|
||||||
|
weight = 2.3
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, tf.constant(weight))
|
||||||
|
self.assertAlmostEqual(loss.eval(), weight * 10.0, 3)
|
||||||
|
|
||||||
|
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[2], [0], [1]])
|
||||||
|
weight = tf.constant([1.2, 3.4, 5.6], shape=[3])
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight)
|
||||||
|
self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)
|
||||||
|
|
||||||
|
def testNonZeroLossWithColumnWeights(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[2], [0], [1]])
|
||||||
|
weight = tf.constant([[1.2], [3.4], [5.6]])
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight)
|
||||||
|
self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)
|
||||||
|
|
||||||
|
def testAllWrongAllWeightsMissing(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[2], [0], [1]])
|
||||||
|
weight = tf.constant([0, 0, 0], shape=[3])
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight)
|
||||||
|
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||||
|
|
||||||
|
def testSomeWeightsMissing(self):
|
||||||
|
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||||
|
[0.0, 10.0, 0.0],
|
||||||
|
[0.0, 0.0, 10.0]])
|
||||||
|
labels = tf.constant([[2], [0], [1]])
|
||||||
|
weight = tf.constant([1.2, 0, 0], shape=[3])
|
||||||
|
with self.test_session():
|
||||||
|
loss = tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight)
|
||||||
|
self.assertAlmostEqual(loss.eval(), 12.0, 3)
|
||||||
|
|
||||||
|
def testMeasurementSpecificWeightsRaisesException(self):
|
||||||
|
with self.test_session():
|
||||||
|
logits = tf.constant([[100.0, -100.0, -100.0],
|
||||||
|
[-100.0, 100.0, -100.0],
|
||||||
|
[-100.0, -100.0, 100.0]])
|
||||||
|
labels = tf.constant([[0], [1], [2]])
|
||||||
|
weight = tf.constant([[3, 4, 5],
|
||||||
|
[2, 6, 0],
|
||||||
|
[8, 0, 1]])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight=weight).eval()
|
||||||
|
|
||||||
|
def testInconsistentWeightSizeRaisesException(self):
|
||||||
|
"""The weight tensor has incorrect number of elements."""
|
||||||
|
with self.test_session():
|
||||||
|
logits = tf.constant([[100.0, -100.0, -100.0],
|
||||||
|
[-100.0, 100.0, -100.0],
|
||||||
|
[-100.0, -100.0, 100.0]])
|
||||||
|
labels = tf.constant([[0], [1], [2]])
|
||||||
|
weight = tf.constant([1.2, 3.4, 5.6, 7.8])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight=weight).eval()
|
||||||
|
|
||||||
|
def testInconsistentLabelSizeRaisesException(self):
|
||||||
|
"""The label tensor has incorrect number of elements."""
|
||||||
|
with self.test_session():
|
||||||
|
logits = tf.constant([[100.0, -100.0, -100.0],
|
||||||
|
[-100.0, 100.0, -100.0],
|
||||||
|
[-100.0, -100.0, 100.0]])
|
||||||
|
labels = tf.constant([[0], [1], [2], [3]])
|
||||||
|
weight = tf.constant([1.2, 3.4, 5.6])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight=weight).eval()
|
||||||
|
|
||||||
|
def testInconsistentWeightShapeRaisesException(self):
|
||||||
|
"""The weight tensor has incorrect shape."""
|
||||||
|
with self.test_session():
|
||||||
|
logits = tf.constant([[100.0, -100.0, -100.0, -100.0],
|
||||||
|
[-100.0, 100.0, -100.0, -100.0],
|
||||||
|
[-100.0, -100.0, 100.0, -100.0],
|
||||||
|
[-100.0, -100.0, -100.0, 100.0]])
|
||||||
|
labels = tf.constant([[0], [1], [2], [3]])
|
||||||
|
weight = tf.constant([[1.2, 3.4], [5.6, 7.8]])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight=weight).eval()
|
||||||
|
|
||||||
|
def testInconsistentLabelShapeRaisesException(self):
|
||||||
|
"""The label tensor has incorrect shape."""
|
||||||
|
with self.test_session():
|
||||||
|
logits = tf.constant([[100.0, -100.0, -100.0, -100.0],
|
||||||
|
[-100.0, 100.0, -100.0, -100.0],
|
||||||
|
[-100.0, -100.0, 100.0, -100.0],
|
||||||
|
[-100.0, -100.0, -100.0, 100.0]])
|
||||||
|
labels = tf.constant([[0, 1], [2, 3]])
|
||||||
|
weight = tf.constant([1.2, 3.4, 5.6, 7.8])
|
||||||
|
|
||||||
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
|
tf.contrib.losses.sparse_softmax_cross_entropy(
|
||||||
|
logits, labels, weight=weight).eval()
|
||||||
|
|
||||||
|
|
||||||
class SigmoidCrossEntropyLossTest(tf.test.TestCase):
|
class SigmoidCrossEntropyLossTest(tf.test.TestCase):
|
||||||
|
|
||||||
def testAllCorrectSigmoid(self):
|
def testAllCorrectSigmoid(self):
|
||||||
|
@ -419,6 +419,7 @@ $(wildcard tensorflow/core/graph/*.cc) \
|
|||||||
$(wildcard tensorflow/core/lib/*/*.cc) \
|
$(wildcard tensorflow/core/lib/*/*.cc) \
|
||||||
$(wildcard tensorflow/core/platform/*.cc) \
|
$(wildcard tensorflow/core/platform/*.cc) \
|
||||||
$(wildcard tensorflow/core/platform/*/*.cc) \
|
$(wildcard tensorflow/core/platform/*/*.cc) \
|
||||||
|
$(wildcard tensorflow/core/platform/*/*/*.cc) \
|
||||||
$(wildcard tensorflow/core/util/*.cc) \
|
$(wildcard tensorflow/core/util/*.cc) \
|
||||||
$(wildcard tensorflow/core/util/*/*.cc)
|
$(wildcard tensorflow/core/util/*/*.cc)
|
||||||
CORE_CC_EXCLUDE_SRCS := \
|
CORE_CC_EXCLUDE_SRCS := \
|
||||||
|
@ -25,7 +25,7 @@ tensorflow/core/lib/random/simple_philox.cc
|
|||||||
tensorflow/core/lib/random/random.cc
|
tensorflow/core/lib/random/random.cc
|
||||||
tensorflow/core/lib/random/distribution_sampler.cc
|
tensorflow/core/lib/random/distribution_sampler.cc
|
||||||
tensorflow/core/lib/io/zlib_outputbuffer.cc
|
tensorflow/core/lib/io/zlib_outputbuffer.cc
|
||||||
tensorflow/core/lib/io/zlib_inputbuffer.cc
|
tensorflow/core/lib/io/zlib_inputstream.cc
|
||||||
tensorflow/core/lib/io/two_level_iterator.cc
|
tensorflow/core/lib/io/two_level_iterator.cc
|
||||||
tensorflow/core/lib/io/table_builder.cc
|
tensorflow/core/lib/io/table_builder.cc
|
||||||
tensorflow/core/lib/io/table.cc
|
tensorflow/core/lib/io/table.cc
|
||||||
|
@ -76,6 +76,7 @@ tensorflow/core/kernels/example_parsing_ops.cc
|
|||||||
tensorflow/core/kernels/dynamic_stitch_op.cc
|
tensorflow/core/kernels/dynamic_stitch_op.cc
|
||||||
tensorflow/core/kernels/dynamic_partition_op.cc
|
tensorflow/core/kernels/dynamic_partition_op.cc
|
||||||
tensorflow/core/kernels/dense_update_ops.cc
|
tensorflow/core/kernels/dense_update_ops.cc
|
||||||
|
tensorflow/core/kernels/deep_conv2d.cc
|
||||||
tensorflow/core/kernels/cwise_ops_common.cc
|
tensorflow/core/kernels/cwise_ops_common.cc
|
||||||
tensorflow/core/kernels/cwise_op_tanh.cc
|
tensorflow/core/kernels/cwise_op_tanh.cc
|
||||||
tensorflow/core/kernels/cwise_op_sub.cc
|
tensorflow/core/kernels/cwise_op_sub.cc
|
||||||
@ -100,6 +101,7 @@ tensorflow/core/kernels/cwise_op_div.cc
|
|||||||
tensorflow/core/kernels/cwise_op_add.cc
|
tensorflow/core/kernels/cwise_op_add.cc
|
||||||
tensorflow/core/kernels/ctc_decoder_ops.cc
|
tensorflow/core/kernels/ctc_decoder_ops.cc
|
||||||
tensorflow/core/kernels/conv_ops_using_gemm.cc
|
tensorflow/core/kernels/conv_ops_using_gemm.cc
|
||||||
|
tensorflow/core/kernels/conv_ops_fused.cc
|
||||||
tensorflow/core/kernels/conv_ops.cc
|
tensorflow/core/kernels/conv_ops.cc
|
||||||
tensorflow/core/kernels/conv_grad_ops.cc
|
tensorflow/core/kernels/conv_grad_ops.cc
|
||||||
tensorflow/core/kernels/control_flow_ops.cc
|
tensorflow/core/kernels/control_flow_ops.cc
|
||||||
|
@ -127,7 +127,6 @@ time.
|
|||||||
|
|
||||||
@@aggregate_metrics
|
@@aggregate_metrics
|
||||||
@@aggregate_metric_map
|
@@aggregate_metric_map
|
||||||
@@run_metric
|
|
||||||
|
|
||||||
## Set `Ops`
|
## Set `Ops`
|
||||||
|
|
||||||
@ -147,7 +146,6 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion
|
|||||||
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
|
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
|
||||||
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
|
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
|
||||||
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
|
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
|
||||||
from tensorflow.contrib.metrics.python.ops.metric_ops import run_metric
|
|
||||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
|
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
|
||||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
|
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
|
||||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean
|
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean
|
||||||
|
@ -213,7 +213,7 @@ void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
|
|||||||
result->clear();
|
result->clear();
|
||||||
auto input_flat = input_tensor.flat<T>();
|
auto input_flat = input_tensor.flat<T>();
|
||||||
const auto start = std::inner_product(
|
const auto start = std::inner_product(
|
||||||
group_indices.begin(), group_indices.end(), input_strides.begin(), 0);
|
group_indices.begin(), group_indices.end(), input_strides.begin(), 0L);
|
||||||
const TensorShape& input_shape = input_tensor.shape();
|
const TensorShape& input_shape = input_tensor.shape();
|
||||||
const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
|
const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
|
||||||
for (int64 i = start; i < end; ++i) {
|
for (int64 i = start; i < end; ++i) {
|
||||||
@ -273,7 +273,7 @@ void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
|
|||||||
|
|
||||||
const auto group_key = group.group();
|
const auto group_key = group.group();
|
||||||
const auto output_index = std::inner_product(
|
const auto output_index = std::inner_product(
|
||||||
group_key.begin(), group_key.end(), output_strides.begin(), 0);
|
group_key.begin(), group_key.end(), output_strides.begin(), 0L);
|
||||||
out(output_index) = group_set.size();
|
out(output_index) = group_set.size();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -441,7 +441,7 @@ void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
|
|||||||
|
|
||||||
std::set<T> group_set;
|
std::set<T> group_set;
|
||||||
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
||||||
if (group_set.size() > 0) {
|
if (!group_set.empty()) {
|
||||||
group_sets[group_indices] = group_set;
|
group_sets[group_indices] = group_set;
|
||||||
const auto set_size = group_set.size();
|
const auto set_size = group_set.size();
|
||||||
if (set_size > max_set_size) {
|
if (set_size > max_set_size) {
|
||||||
@ -516,7 +516,7 @@ void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
|
|||||||
|
|
||||||
std::set<T> group_set;
|
std::set<T> group_set;
|
||||||
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
||||||
if (group_set.size() > 0) {
|
if (!group_set.empty()) {
|
||||||
group_sets[group_indices] = group_set;
|
group_sets[group_indices] = group_set;
|
||||||
const auto set_size = group_set.size();
|
const auto set_size = group_set.size();
|
||||||
if (set_size > max_set_size) {
|
if (set_size > max_set_size) {
|
||||||
@ -632,7 +632,7 @@ void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
|
|||||||
|
|
||||||
std::set<T> group_set;
|
std::set<T> group_set;
|
||||||
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
||||||
if (group_set.size() > 0) {
|
if (!group_set.empty()) {
|
||||||
group_sets[*group_indices] = group_set;
|
group_sets[*group_indices] = group_set;
|
||||||
const auto set_size = group_set.size();
|
const auto set_size = group_set.size();
|
||||||
if (set_size > max_set_size) {
|
if (set_size > max_set_size) {
|
||||||
|
@ -121,7 +121,7 @@ class ConfusionMatrixTest(tf.test.TestCase):
|
|||||||
predictions = np.asarray([1, 2, 3])
|
predictions = np.asarray([1, 2, 3])
|
||||||
labels = np.asarray([1, 2])
|
labels = np.asarray([1, 2])
|
||||||
self.assertRaisesRegexp(
|
self.assertRaisesRegexp(
|
||||||
ValueError, "are not compatible",
|
ValueError, "must be equal",
|
||||||
tf.contrib.metrics.confusion_matrix, predictions, labels)
|
tf.contrib.metrics.confusion_matrix, predictions, labels)
|
||||||
|
|
||||||
def testOutputIsInt32(self):
|
def testOutputIsInt32(self):
|
||||||
|
@ -22,8 +22,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||||
|
|
||||||
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
|
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
|
||||||
@ -467,6 +465,8 @@ def streaming_accuracy(predictions, labels, weights=None,
|
|||||||
predictions, labels = metric_ops_util.remove_squeezable_dimensions(
|
predictions, labels = metric_ops_util.remove_squeezable_dimensions(
|
||||||
predictions, labels)
|
predictions, labels)
|
||||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||||
|
if labels.dtype != predictions.dtype:
|
||||||
|
predictions = math_ops.cast(predictions, labels.dtype)
|
||||||
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
|
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
|
||||||
return streaming_mean(is_correct, weights, metrics_collections,
|
return streaming_mean(is_correct, weights, metrics_collections,
|
||||||
updates_collections, name or 'accuracy')
|
updates_collections, name or 'accuracy')
|
||||||
@ -2126,37 +2126,4 @@ def aggregate_metric_map(names_to_tuples):
|
|||||||
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
|
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
|
||||||
|
|
||||||
|
|
||||||
def run_metric(metric, predictions, targets, weights=None):
|
|
||||||
"""Runs a single metric.
|
|
||||||
|
|
||||||
This function runs metric on given predictions and targets. weights will be
|
|
||||||
used if metric contains 'weights' in its argument.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metric: A function that evaluates targets given predictions.
|
|
||||||
predictions: A `Tensor` of arbitrary shape.
|
|
||||||
targets: A `Tensor` of the same shape as `predictions`.
|
|
||||||
weights: A set of weights that can be used in metric function to compute
|
|
||||||
weighted result.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
result: result returned by metric function.
|
|
||||||
"""
|
|
||||||
metric_args = []
|
|
||||||
if hasattr(metric, '__code__'):
|
|
||||||
# Regular function.
|
|
||||||
metric_args = inspect.getargspec(metric).args
|
|
||||||
elif hasattr(metric, 'func') and hasattr(metric, 'keywords'):
|
|
||||||
# Partial function.
|
|
||||||
for arg in inspect.getargspec(metric.func).args:
|
|
||||||
if metric.keywords and arg not in metric.keywords.keys():
|
|
||||||
metric_args.append(arg)
|
|
||||||
if 'weights' in metric_args:
|
|
||||||
result = metric(predictions, targets, weights=weights)
|
|
||||||
else:
|
|
||||||
result = metric(predictions, targets)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = make_all(__name__)
|
__all__ = make_all(__name__)
|
||||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -2851,37 +2850,5 @@ class AggregateMetricMapTest(tf.test.TestCase):
|
|||||||
self.assertEqual(4, names_to_values['m2'].eval())
|
self.assertEqual(4, names_to_values['m2'].eval())
|
||||||
|
|
||||||
|
|
||||||
class RunMetricTest(tf.test.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
tf.reset_default_graph()
|
|
||||||
|
|
||||||
def testRunMetric(self):
|
|
||||||
predictions = tf.constant([2, 4, 6, 8], shape=(1, 4), dtype=tf.float32)
|
|
||||||
labels = tf.constant([1, 3, 2, 3], shape=(1, 4), dtype=tf.float32)
|
|
||||||
weights = tf.constant([0, 1, 0, 1], shape=(1, 4))
|
|
||||||
|
|
||||||
error, update_op = metrics.run_metric(metrics.streaming_mean_squared_error,
|
|
||||||
predictions, labels, weights)
|
|
||||||
with self.test_session() as sess:
|
|
||||||
sess.run(tf.initialize_local_variables())
|
|
||||||
self.assertEqual(13, sess.run(update_op))
|
|
||||||
self.assertEqual(13, error.eval())
|
|
||||||
|
|
||||||
def testRunMetricsWithOutWeights(self):
|
|
||||||
predictions = tf.constant([2, 4, 6], shape=(1, 3), dtype=tf.float32)
|
|
||||||
labels = tf.constant([1, 3, 2], shape=(1, 3), dtype=tf.float32)
|
|
||||||
|
|
||||||
streaming_mean_squared_error_no_weight = partial(
|
|
||||||
metrics.streaming_mean_squared_error, weights=None)
|
|
||||||
|
|
||||||
error, update_op = metrics.run_metric(
|
|
||||||
streaming_mean_squared_error_no_weight, predictions, labels)
|
|
||||||
with self.test_session() as sess:
|
|
||||||
sess.run(tf.initialize_local_variables())
|
|
||||||
self.assertEqual(6, sess.run(update_op))
|
|
||||||
self.assertEqual(6, error.eval())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#ifdef USE_HEXAGON_LIBS
|
#ifdef USE_HEXAGON_LIBS
|
||||||
#include "tensorflow/core/platform/hexagon/gemm_wrapper.h"
|
#include "tensorflow/core/platform/hexagon/gemm_wrapper.h"
|
||||||
|
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -49,6 +50,30 @@ class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Shows some statistics of hexagon dsp using hexagon specific APIs
|
||||||
|
#ifdef USE_HEXAGON_LIBS
|
||||||
|
TEST_F(QuantizedMatMulOpForHexagonTest, EvaluateSharedLibOverhead) {
|
||||||
|
const uint64 overhead_shared_lib_start =
|
||||||
|
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||||
|
const int wrapper_version = hexagon_gemm_wrapper_GetWrapperVersion();
|
||||||
|
const uint64 overhead_shared_lib_end =
|
||||||
|
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||||
|
const uint64 overhead_hexagon_rpc_start =
|
||||||
|
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||||
|
const int hexagon_binary_version =
|
||||||
|
hexagon_gemm_wrapper_GetHexagonBinaryVersion();
|
||||||
|
const uint64 overhead_hexagon_rpc_end =
|
||||||
|
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||||
|
LOG(INFO) << "Shared lib (ver = " << wrapper_version << ") overhead is "
|
||||||
|
<< (overhead_shared_lib_end - overhead_shared_lib_start)
|
||||||
|
<< " cycles";
|
||||||
|
LOG(INFO) << "hexagon rpc (ver = " << hexagon_binary_version
|
||||||
|
<< ") overhead is "
|
||||||
|
<< (overhead_hexagon_rpc_end - overhead_hexagon_rpc_start)
|
||||||
|
<< " cycles";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// Runs two small matrices through the operator, and leaves all the parameters
|
// Runs two small matrices through the operator, and leaves all the parameters
|
||||||
// at their default values.
|
// at their default values.
|
||||||
// This test is a sample to execute matmul on hexagon.
|
// This test is a sample to execute matmul on hexagon.
|
||||||
|
@ -28,6 +28,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/rnn_cell_test.py"],
|
srcs = ["python/kernel_tests/rnn_cell_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":rnn_py",
|
":rnn_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -39,6 +40,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/lstm_ops_test.py"],
|
srcs = ["python/kernel_tests/lstm_ops_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":rnn_py",
|
":rnn_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -82,6 +84,7 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/gru_ops_test.py"],
|
srcs = ["python/kernel_tests/gru_ops_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":rnn_py",
|
":rnn_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
@ -11,6 +11,7 @@ py_library(
|
|||||||
name = "training_py",
|
name = "training_py",
|
||||||
srcs = [
|
srcs = [
|
||||||
"__init__.py",
|
"__init__.py",
|
||||||
|
"python/training/bucket_ops.py",
|
||||||
"python/training/sampling_ops.py",
|
"python/training/sampling_ops.py",
|
||||||
"python/training/sequence_queueing_state_saver.py",
|
"python/training/sequence_queueing_state_saver.py",
|
||||||
],
|
],
|
||||||
@ -67,6 +68,18 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "bucket_ops_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["python/training/bucket_ops_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":training_py",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "all_files",
|
name = "all_files",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
|
@ -38,6 +38,17 @@ balanced.
|
|||||||
|
|
||||||
@@stratified_sample
|
@@stratified_sample
|
||||||
@@stratified_sample_unknown_dist
|
@@stratified_sample_unknown_dist
|
||||||
|
|
||||||
|
## Bucketing
|
||||||
|
|
||||||
|
Use ['bucket'](#bucket) or
|
||||||
|
['bucket_by_sequence_length'](#bucket_by_sequence_length) to stratify
|
||||||
|
minibatches into groups ("buckets"). Use `bucket_by_sequence_length`
|
||||||
|
with the argument `dynamic_pad=True` to receive minibatches of similarly
|
||||||
|
sized sequences for efficient training via `dynamic_rnn`.
|
||||||
|
|
||||||
|
@@bucket
|
||||||
|
@@bucket_by_sequence_length
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
@ -45,6 +56,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=unused-import,wildcard-import
|
# pylint: disable=unused-import,wildcard-import
|
||||||
|
from tensorflow.contrib.training.python.training.bucket_ops import *
|
||||||
from tensorflow.contrib.training.python.training.sampling_ops import *
|
from tensorflow.contrib.training.python.training.sampling_ops import *
|
||||||
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
||||||
from tensorflow.python.util.all_util import make_all
|
from tensorflow.python.util.all_util import make_all
|
||||||
|
374
tensorflow/contrib/training/python/training/bucket_ops.py
Normal file
374
tensorflow/contrib/training/python/training/bucket_ops.py
Normal file
@ -0,0 +1,374 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Operations for bucketing data into groups.
|
||||||
|
|
||||||
|
The classes and functions in this module are used to queue up data into
|
||||||
|
buckets conditional on side information (e.g. sequence length).
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import data_flow_ops
|
||||||
|
from tensorflow.python.ops import logging_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.training import input as input_py
|
||||||
|
from tensorflow.python.training import queue_runner
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
_as_original_type = input_py._as_original_type
|
||||||
|
_as_tensor_list = input_py._as_tensor_list
|
||||||
|
_deserialize_sparse_tensors = input_py._deserialize_sparse_tensors
|
||||||
|
_dtypes = input_py._dtypes
|
||||||
|
_serialize_sparse_tensors = input_py._serialize_sparse_tensors
|
||||||
|
_shapes = input_py._shapes
|
||||||
|
_which_queue = input_py._which_queue
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_bucket(tensor_list):
|
||||||
|
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
|
||||||
|
if not tensor_list:
|
||||||
|
raise ValueError("Expected at least one tensor in bucket().")
|
||||||
|
return tensor_list
|
||||||
|
|
||||||
|
|
||||||
|
def bucket(tensors,
|
||||||
|
which_bucket,
|
||||||
|
batch_size,
|
||||||
|
num_buckets,
|
||||||
|
num_threads=1,
|
||||||
|
capacity=32,
|
||||||
|
shapes=None,
|
||||||
|
dynamic_pad=False,
|
||||||
|
allow_smaller_final_batch=False,
|
||||||
|
keep_input=None,
|
||||||
|
shared_name=None,
|
||||||
|
name=None):
|
||||||
|
"""Lazy bucketing of input tensors according to `which_bucket`.
|
||||||
|
|
||||||
|
The argument `tensors` can be a list or a dictionary of tensors.
|
||||||
|
The value returned by the function will be of the same type
|
||||||
|
as `tensors`.
|
||||||
|
|
||||||
|
The tensors entering this function are put into the bucket given by
|
||||||
|
`which_bucket`. Each bucket has its own queue. When a bucket contains
|
||||||
|
`batch_size` elements, this minibatch is pushed onto a top queue. The
|
||||||
|
tensors returned from this function are the result of dequeueing the
|
||||||
|
next minibatch from this top queue.
|
||||||
|
|
||||||
|
This function is implemented using several queues. A `QueueRunner` for the
|
||||||
|
queues is added to the current `Graph`'s `QUEUE_RUNNER` collection.
|
||||||
|
|
||||||
|
As the returned tensors are the result of a dequeue operation, evaluating
|
||||||
|
them will throw a `tf.errors.OutOfRangeError` when the input queue is
|
||||||
|
exhausted. If these tensors are feeding another input queue, its queue runner
|
||||||
|
will catch this exception, however, if they are used in your main thread
|
||||||
|
you are responsible for catching this yourself.
|
||||||
|
|
||||||
|
*N.B.:* If `dynamic_pad` is `False`, you must ensure that either
|
||||||
|
(i) the `shapes` argument is passed, or (ii) all of the tensors in
|
||||||
|
`tensors` must have fully-defined shapes. `ValueError` will be
|
||||||
|
raised if neither of these conditions holds.
|
||||||
|
|
||||||
|
If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
|
||||||
|
tensors is known, but individual dimensions may have shape `None`.
|
||||||
|
In this case, for each enqueue the dimensions with value `None`
|
||||||
|
may have a variable length; upon dequeue, the output tensors will be padded
|
||||||
|
on the right to the maximum shape of the tensors in the current minibatch.
|
||||||
|
For numbers, this padding takes value 0. For strings, this padding is
|
||||||
|
the empty string. See `PaddingFIFOQueue` for more info.
|
||||||
|
|
||||||
|
If `allow_smaller_final_batch` is `True`, a smaller batch value than
|
||||||
|
`batch_size` is returned when the queues are closed and there are not enough
|
||||||
|
elements to fill the batch, otherwise the pending elements are discarded.
|
||||||
|
In addition, all output tensors' static shapes, as accessed via the
|
||||||
|
`get_shape()` method will have a 0th `Dimension` value of `None`, and
|
||||||
|
operations that depend on fixed batch_size would fail.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensors: The list or dictionary of tensors, representing a single element,
|
||||||
|
to bucket. Nested lists are not supported.
|
||||||
|
which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
|
||||||
|
batch_size: The new batch size pulled from the queue
|
||||||
|
(python int or int32 scalar).
|
||||||
|
num_buckets: A python integer, the number of buckets.
|
||||||
|
num_threads: An integer. The number of threads enqueuing `tensors`.
|
||||||
|
capacity: An integer. The maximum number of minibatches in the top queue,
|
||||||
|
and also the maximum number of elements within each bucket.
|
||||||
|
shapes: (Optional) The shapes for each example. Defaults to the
|
||||||
|
inferred shapes for `tensors`.
|
||||||
|
dynamic_pad: Boolean. Allow variable dimensions in input shapes.
|
||||||
|
The given dimensions are padded upon dequeue so that tensors within a
|
||||||
|
batch have the same shapes.
|
||||||
|
allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
|
||||||
|
batches to be smaller if there are insufficient items left in the queues.
|
||||||
|
keep_input: (Optional). A `bool` scalar Tensor. If provided, this tensor
|
||||||
|
controls whether the input is added to the queue or not. If it evaluates
|
||||||
|
`True`, then `tensors` are added to the bucket; otherwise they are
|
||||||
|
dropped. This tensor essentially acts as a filtering mechanism.
|
||||||
|
The default behavior is to assume `keep_input=True`.
|
||||||
|
shared_name: (Optional). If set, the queues will be shared under the given
|
||||||
|
name across multiple sessions.
|
||||||
|
name: (Optional) A name for the operations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple `(bucket, outputs)` where `bucket` is
|
||||||
|
an `int32` scalar tensor and `outputs` is a list or
|
||||||
|
dictionary of batched outputs corresponding to elements of `tensors`.
|
||||||
|
Every step will receive a new bucket of outputs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the `shapes` are not specified, and cannot be
|
||||||
|
inferred from the elements of `tensors`.
|
||||||
|
"""
|
||||||
|
tensor_list = _as_tensor_list(tensors)
|
||||||
|
with ops.name_scope(name, "bucket", tensor_list) as name:
|
||||||
|
tensor_list = _validate_bucket(tensor_list)
|
||||||
|
(tensor_list, sparse_info) = _serialize_sparse_tensors(
|
||||||
|
tensor_list, enqueue_many=False)
|
||||||
|
|
||||||
|
# Round-trip batch_size to a tensor, and possibly back
|
||||||
|
batch_size = ops.convert_to_tensor(
|
||||||
|
batch_size, dtype=dtypes.int32, name="batch_size")
|
||||||
|
static_batch_size = tensor_util.constant_value(batch_size)
|
||||||
|
batch_size = (
|
||||||
|
static_batch_size if static_batch_size is not None else batch_size)
|
||||||
|
|
||||||
|
types = _dtypes([tensor_list])
|
||||||
|
shapes = _shapes([tensor_list], shapes, enqueue_many=False)
|
||||||
|
|
||||||
|
which_bucket = ops.convert_to_tensor(
|
||||||
|
which_bucket, dtype=dtypes.int32, name="which_bucket")
|
||||||
|
|
||||||
|
queue_creator = _which_queue(dynamic_pad)
|
||||||
|
bucket_queues = []
|
||||||
|
for i in range(num_buckets):
|
||||||
|
shared_name_i = (
|
||||||
|
"%s_%d" % (shared_name, i) if shared_name is not None else None)
|
||||||
|
bucket_queues.append(
|
||||||
|
queue_creator(capacity=capacity,
|
||||||
|
dtypes=types,
|
||||||
|
shapes=shapes,
|
||||||
|
shared_name=shared_name_i, name="bucket_queue_%d" % i))
|
||||||
|
|
||||||
|
maybe_static_batch_size = (
|
||||||
|
None if allow_smaller_final_batch else static_batch_size)
|
||||||
|
|
||||||
|
bucket_shapes = [tensor_shape.vector(maybe_static_batch_size).concatenate(s)
|
||||||
|
for s in bucket_queues[0].shapes]
|
||||||
|
# top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO
|
||||||
|
# queues because if we use allow_smaller_final_batch, shapes will
|
||||||
|
# contain Nones in their first entry; as a result, a regular
|
||||||
|
# FIFOQueue would die when being passed shapes that are not fully defined.
|
||||||
|
top_queue = data_flow_ops.PaddingFIFOQueue(
|
||||||
|
capacity=capacity,
|
||||||
|
dtypes=[dtypes.int32] + types,
|
||||||
|
shapes=[tensor_shape.scalar()] + bucket_shapes,
|
||||||
|
shared_name=shared_name, name="top_queue")
|
||||||
|
|
||||||
|
def enqueue_which():
|
||||||
|
def enqueue_single(i):
|
||||||
|
return bucket_queues[i].enqueue(tensor_list)
|
||||||
|
enqueues = [
|
||||||
|
control_flow_ops.cond(
|
||||||
|
math_ops.equal(which_bucket, i),
|
||||||
|
functools.partial(enqueue_single, i),
|
||||||
|
control_flow_ops.no_op)
|
||||||
|
for i in range(num_buckets)]
|
||||||
|
return control_flow_ops.group(*enqueues, name="group_enqueues")
|
||||||
|
|
||||||
|
if keep_input is not None:
|
||||||
|
# TODO(ebrevdo): Expand keep_input param to core training
|
||||||
|
# methods, and pipe through to _serialize_sparse_tensors; so
|
||||||
|
# that expensive serialization is guarded by keep_input.
|
||||||
|
maybe_enqueue = control_flow_ops.cond(
|
||||||
|
keep_input,
|
||||||
|
enqueue_which,
|
||||||
|
control_flow_ops.no_op)
|
||||||
|
else:
|
||||||
|
maybe_enqueue = enqueue_which()
|
||||||
|
|
||||||
|
bucket_enqueue_ops = [maybe_enqueue] * num_threads
|
||||||
|
|
||||||
|
if allow_smaller_final_batch:
|
||||||
|
which_dequeue = lambda q: q.dequeue_up_to
|
||||||
|
else:
|
||||||
|
which_dequeue = lambda q: q.dequeue_many
|
||||||
|
|
||||||
|
enqueues_to_top = [
|
||||||
|
top_queue.enqueue(
|
||||||
|
[constant_op.constant(i)] +
|
||||||
|
which_dequeue(q)(batch_size, name="read_bucket_%d" % i),
|
||||||
|
name="enqueue_from_bucket_%d" % i)
|
||||||
|
for i, q in enumerate(bucket_queues)]
|
||||||
|
|
||||||
|
for i, q in enumerate(bucket_queues):
|
||||||
|
queue_runner.add_queue_runner(queue_runner.QueueRunner(
|
||||||
|
q, [enqueues_to_top[i]],
|
||||||
|
queue_closed_exception_types=(
|
||||||
|
errors.OutOfRangeError, errors.CancelledError)))
|
||||||
|
queue_runner.add_queue_runner(queue_runner.QueueRunner(
|
||||||
|
top_queue, bucket_enqueue_ops,
|
||||||
|
queue_closed_exception_types=(
|
||||||
|
errors.OutOfRangeError, errors.CancelledError)))
|
||||||
|
|
||||||
|
for q in bucket_queues:
|
||||||
|
logging_ops.scalar_summary(
|
||||||
|
"bucket/%s/size" % q.name,
|
||||||
|
math_ops.cast(top_queue.size(), dtypes.float32))
|
||||||
|
logging_ops.scalar_summary(
|
||||||
|
"bucket/%s/fraction_of_%d_full" % (top_queue.name, capacity),
|
||||||
|
math_ops.cast(top_queue.size(), dtypes.float32) * (1. / capacity))
|
||||||
|
|
||||||
|
dequeued = top_queue.dequeue(name="dequeue_top")
|
||||||
|
which_bucket_dequeued = dequeued[0]
|
||||||
|
dequeued = dequeued[1:]
|
||||||
|
dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
|
||||||
|
return (which_bucket_dequeued, _as_original_type(tensors, dequeued))
|
||||||
|
|
||||||
|
|
||||||
|
def bucket_by_sequence_length(input_length,
|
||||||
|
tensors,
|
||||||
|
batch_size,
|
||||||
|
bucket_boundaries,
|
||||||
|
num_threads=1,
|
||||||
|
capacity=32,
|
||||||
|
shapes=None,
|
||||||
|
dynamic_pad=False,
|
||||||
|
allow_smaller_final_batch=False,
|
||||||
|
keep_input=None,
|
||||||
|
shared_name=None,
|
||||||
|
name=None):
|
||||||
|
"""Lazy bucketing of inputs according to their length.
|
||||||
|
|
||||||
|
This method calls `tf.contrib.training.bucket` under the hood, after first
|
||||||
|
subdividing the bucket boundaries into separate buckets and identifying which
|
||||||
|
bucket the given `input_length` belongs to. See the documentation for
|
||||||
|
`tf.contrib.training.bucket` for details of the other arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_length: `int32` scalar `Tensor`, the sequence length of tensors.
|
||||||
|
tensors: The list or dictionary of tensors, representing a single element,
|
||||||
|
to bucket. Nested lists are not supported.
|
||||||
|
batch_size: The new batch size pulled from the queue
|
||||||
|
(python int or int32 scalar).
|
||||||
|
bucket_boundaries: int list, increasing non-negative numbers.
|
||||||
|
The edges of the buckets to use when bucketing tensors. Two extra buckets
|
||||||
|
are created, one for `input_length < bucket_boundaries[0]` and
|
||||||
|
one for `input_length >= bucket_boundaries[-1]`.
|
||||||
|
num_threads: An integer. The number of threads enqueuing `tensors`.
|
||||||
|
capacity: An integer. The maximum number of minibatches in the top queue,
|
||||||
|
and also the maximum number of elements within each bucket.
|
||||||
|
shapes: (Optional) The shapes for each example. Defaults to the
|
||||||
|
inferred shapes for `tensors`.
|
||||||
|
dynamic_pad: Boolean. Allow variable dimensions in input shapes.
|
||||||
|
The given dimensions are padded upon dequeue so that tensors within a
|
||||||
|
batch have the same shapes.
|
||||||
|
allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
|
||||||
|
batches to be smaller if there are insufficient items left in the queues.
|
||||||
|
keep_input: (Optional). A `bool` scalar Tensor. If provided, this tensor
|
||||||
|
controls whether the input is added to the queue or not. If it evaluates
|
||||||
|
`True`, then `tensors` are added to the bucket; otherwise they are
|
||||||
|
dropped. This tensor essentially acts as a filtering mechanism.
|
||||||
|
The default behavior is to assume `keep_input=True`.
|
||||||
|
shared_name: (Optional). If set, the queues will be shared under the given
|
||||||
|
name across multiple sessions.
|
||||||
|
name: (Optional) A name for the operations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple `(sequence_length, outputs)` where `sequence_length` is
|
||||||
|
a 1-D `Tensor` of size `batch_size` and `outputs` is a list or dictionary
|
||||||
|
of batched, bucketed, outputs corresponding to elements of `tensors`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `bucket_boundaries` is not a list or tuple of python integers.
|
||||||
|
ValueError: if `bucket_boundaries` is empty or contains non-increasing
|
||||||
|
values.
|
||||||
|
"""
|
||||||
|
tensor_list = _as_tensor_list(tensors)
|
||||||
|
if not isinstance(bucket_boundaries, (list, tuple)):
|
||||||
|
raise TypeError(
|
||||||
|
"bucket_boundaries must be a list or tuple, but received: %s"
|
||||||
|
% bucket_boundaries)
|
||||||
|
if not bucket_boundaries:
|
||||||
|
raise ValueError("bucket_boundaries must not be empty")
|
||||||
|
for (s, e) in zip(bucket_boundaries[:-1], bucket_boundaries[1:]):
|
||||||
|
if not isinstance(s, int) or not isinstance(e, int):
|
||||||
|
raise TypeError(
|
||||||
|
"bucket boundaries must be integers, but saw: %s and %s" % (s, e))
|
||||||
|
if s >= e:
|
||||||
|
raise ValueError(
|
||||||
|
"Buckets must contain sequential increasing lengths, but saw: "
|
||||||
|
"%d before %d" % (s, e))
|
||||||
|
|
||||||
|
with ops.name_scope(name, "bucket_by_sequence_length",
|
||||||
|
[input_length] + tensor_list) as name:
|
||||||
|
input_length = ops.convert_to_tensor(
|
||||||
|
input_length, dtype=dtypes.int32, name="input_length")
|
||||||
|
# Bucketing conditions are:
|
||||||
|
# l < b[0]
|
||||||
|
# b[0] <= l < b[1]
|
||||||
|
# b[1] <= l < b[2]
|
||||||
|
# ...
|
||||||
|
# b[N-2] <= l < b[N-1]
|
||||||
|
# b[N-1] <= l
|
||||||
|
# Equivalent to:
|
||||||
|
# [-inf, b[0], b[1], ..., b[N-1]] <= l < [b[0], b[1], ..., b[N-1], inf]
|
||||||
|
buckets_min = [np.iinfo(np.int32).min] + list(bucket_boundaries)
|
||||||
|
buckets_max = list(bucket_boundaries) + [np.iinfo(np.int32).max]
|
||||||
|
conditions_c = math_ops.logical_and(
|
||||||
|
math_ops.less_equal(buckets_min, input_length),
|
||||||
|
math_ops.less(input_length, buckets_max))
|
||||||
|
which_bucket = math_ops.reduce_min(array_ops.where(conditions_c))
|
||||||
|
which_bucket = math_ops.to_int32(which_bucket)
|
||||||
|
|
||||||
|
if shapes is not None:
|
||||||
|
shapes = [tensor_shape.scalar()] + shapes
|
||||||
|
|
||||||
|
_, dequeued = bucket(
|
||||||
|
tensors=[input_length] + tensor_list,
|
||||||
|
which_bucket=which_bucket,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_buckets=len(bucket_boundaries) + 1,
|
||||||
|
num_threads=num_threads,
|
||||||
|
capacity=capacity,
|
||||||
|
shapes=shapes,
|
||||||
|
dynamic_pad=dynamic_pad,
|
||||||
|
allow_smaller_final_batch=allow_smaller_final_batch,
|
||||||
|
keep_input=keep_input,
|
||||||
|
shared_name=shared_name)
|
||||||
|
|
||||||
|
return (dequeued[0], _as_original_type(tensors, dequeued[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"bucket",
|
||||||
|
"bucket_by_sequence_length"
|
||||||
|
]
|
356
tensorflow/contrib/training/python/training/bucket_ops_test.py
Normal file
356
tensorflow/contrib/training/python/training/bucket_ops_test.py
Normal file
@ -0,0 +1,356 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for tf.contrib.training.bucket."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def _which_bucket(bucket_edges, v):
|
||||||
|
"""Identify which bucket v falls into.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bucket_edges: int array, bucket edges
|
||||||
|
v: int scalar, index
|
||||||
|
Returns:
|
||||||
|
int scalar, the bucket.
|
||||||
|
If 0 <= v < bucket_edges[0], return 0.
|
||||||
|
If bucket_edges[0] <= v < bucket_edges[1], return 1.
|
||||||
|
...
|
||||||
|
If bucket_edges[-2] <= v < bucket_edges[-1], return len(bucket_edges) - 1.
|
||||||
|
If v >= bucket_edges[-1] (or v < 0), return len(bucket_edges) + 1
|
||||||
|
"""
|
||||||
|
v = np.asarray(v)
|
||||||
|
full = [0] + bucket_edges
|
||||||
|
found = np.where(np.logical_and(v >= full[:-1], v < full[1:]))[0]
|
||||||
|
if not found.size:
|
||||||
|
return len(full)
|
||||||
|
return found[0]
|
||||||
|
|
||||||
|
|
||||||
|
class BucketTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
self.scalar_int_feed = tf.placeholder(tf.int32, ())
|
||||||
|
self.unk_int64_feed = tf.placeholder(tf.int64, (None,))
|
||||||
|
self.vec3_str_feed = tf.placeholder(tf.string, (3,))
|
||||||
|
|
||||||
|
self._coord = tf.train.Coordinator()
|
||||||
|
# Make capacity very large so we can feed all the inputs in the
|
||||||
|
# main thread without blocking
|
||||||
|
input_queue = tf.PaddingFIFOQueue(
|
||||||
|
5000,
|
||||||
|
dtypes=[tf.int32, tf.int64, tf.string],
|
||||||
|
shapes=[(), (None,), (3,)])
|
||||||
|
|
||||||
|
self._input_enqueue_op = input_queue.enqueue(
|
||||||
|
(self.scalar_int_feed, self.unk_int64_feed, self.vec3_str_feed))
|
||||||
|
self.scalar_int, self.unk_int64, self.vec3_str = input_queue.dequeue()
|
||||||
|
self._threads = None
|
||||||
|
self._close_op = input_queue.close()
|
||||||
|
self._sess = None
|
||||||
|
|
||||||
|
def enqueue_inputs(self, sess, feed_dict):
|
||||||
|
sess.run(self._input_enqueue_op, feed_dict=feed_dict)
|
||||||
|
|
||||||
|
def start_queue_runners(self, sess):
|
||||||
|
# Store session to be able to close inputs later
|
||||||
|
if self._sess is None:
|
||||||
|
self._sess = sess
|
||||||
|
self._threads = tf.train.start_queue_runners(coord=self._coord)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
if self._sess is not None:
|
||||||
|
self._sess.run(self._close_op)
|
||||||
|
self._coord.request_stop()
|
||||||
|
self._coord.join(self._threads)
|
||||||
|
|
||||||
|
def testSingleBucket(self):
|
||||||
|
bucketed_dynamic = tf.contrib.training.bucket(
|
||||||
|
tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
|
||||||
|
which_bucket=tf.constant(0),
|
||||||
|
num_buckets=2,
|
||||||
|
batch_size=32,
|
||||||
|
num_threads=10,
|
||||||
|
dynamic_pad=True)
|
||||||
|
# Check shape inference on bucketing outputs
|
||||||
|
self.assertAllEqual(
|
||||||
|
[[32], [32, None], [32, 3]],
|
||||||
|
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
|
||||||
|
with self.test_session() as sess:
|
||||||
|
for v in range(32):
|
||||||
|
self.enqueue_inputs(
|
||||||
|
sess,
|
||||||
|
{self.scalar_int_feed: v,
|
||||||
|
self.unk_int64_feed: v * [v],
|
||||||
|
self.vec3_str_feed: 3 * [str(v)]})
|
||||||
|
self.start_queue_runners(sess)
|
||||||
|
|
||||||
|
# Get a single minibatch
|
||||||
|
bucketed_values = sess.run(bucketed_dynamic)
|
||||||
|
|
||||||
|
# (which_bucket, bucket_tensors).
|
||||||
|
self.assertEqual(2, len(bucketed_values))
|
||||||
|
|
||||||
|
# Count number of bucket_tensors.
|
||||||
|
self.assertEqual(3, len(bucketed_values[1]))
|
||||||
|
|
||||||
|
# Ensure bucket 0 was used for all minibatch entries.
|
||||||
|
self.assertAllEqual(0, bucketed_values[0])
|
||||||
|
|
||||||
|
expected_scalar_int = np.arange(32)
|
||||||
|
expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
|
||||||
|
for i in range(32):
|
||||||
|
expected_unk_int64[i, :i] = i
|
||||||
|
expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
|
||||||
|
|
||||||
|
# Must resort the output because num_threads > 1 leads to
|
||||||
|
# sometimes-inconsistent insertion order.
|
||||||
|
resort = np.argsort(bucketed_values[1][0])
|
||||||
|
self.assertAllEqual(expected_scalar_int, bucketed_values[1][0][resort])
|
||||||
|
self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
|
||||||
|
self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])
|
||||||
|
|
||||||
|
def testEvenOddBuckets(self):
|
||||||
|
which_bucket = (self.scalar_int % 2)
|
||||||
|
bucketed_dynamic = tf.contrib.training.bucket(
|
||||||
|
tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
|
||||||
|
which_bucket=which_bucket,
|
||||||
|
num_buckets=2,
|
||||||
|
batch_size=32,
|
||||||
|
num_threads=10,
|
||||||
|
dynamic_pad=True)
|
||||||
|
# Check shape inference on bucketing outputs
|
||||||
|
self.assertAllEqual(
|
||||||
|
[[32], [32, None], [32, 3]],
|
||||||
|
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
|
||||||
|
with self.test_session() as sess:
|
||||||
|
for v in range(64):
|
||||||
|
self.enqueue_inputs(
|
||||||
|
sess,
|
||||||
|
{self.scalar_int_feed: v,
|
||||||
|
self.unk_int64_feed: v * [v],
|
||||||
|
self.vec3_str_feed: 3 * [str(v)]})
|
||||||
|
self.start_queue_runners(sess)
|
||||||
|
|
||||||
|
# Get two minibatches (one containing even values, one containing odds)
|
||||||
|
bucketed_values_0 = sess.run(bucketed_dynamic)
|
||||||
|
bucketed_values_1 = sess.run(bucketed_dynamic)
|
||||||
|
|
||||||
|
# (which_bucket, bucket_tensors).
|
||||||
|
self.assertEqual(2, len(bucketed_values_0))
|
||||||
|
self.assertEqual(2, len(bucketed_values_1))
|
||||||
|
|
||||||
|
# Count number of bucket_tensors.
|
||||||
|
self.assertEqual(3, len(bucketed_values_0[1]))
|
||||||
|
self.assertEqual(3, len(bucketed_values_1[1]))
|
||||||
|
|
||||||
|
# Figure out which output has the even values (there's
|
||||||
|
# randomness due to the multithreaded nature of bucketing)
|
||||||
|
if bucketed_values_0[0] % 2 == 1:
|
||||||
|
bucketed_values_even, bucketed_values_odd = (
|
||||||
|
bucketed_values_1, bucketed_values_0)
|
||||||
|
else:
|
||||||
|
bucketed_values_even, bucketed_values_odd = (
|
||||||
|
bucketed_values_0, bucketed_values_1)
|
||||||
|
|
||||||
|
# Ensure bucket 0 was used for the even minibatch and bucket 1 for the odd one.
|
||||||
|
self.assertAllEqual(0, bucketed_values_even[0])
|
||||||
|
self.assertAllEqual(1, bucketed_values_odd[0])
|
||||||
|
|
||||||
|
# Test the first bucket outputted, the evens starting at 0
|
||||||
|
expected_scalar_int = np.arange(0, 32 * 2, 2)
|
||||||
|
expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
|
||||||
|
for i in range(0, 32):
|
||||||
|
expected_unk_int64[i, :2*i] = 2*i
|
||||||
|
expected_vec3_str = np.vstack(
|
||||||
|
3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
|
||||||
|
|
||||||
|
# Must resort the output because num_threads > 1 leads to
|
||||||
|
# sometimes-inconsistent insertion order.
|
||||||
|
resort = np.argsort(bucketed_values_even[1][0])
|
||||||
|
self.assertAllEqual(expected_scalar_int,
|
||||||
|
bucketed_values_even[1][0][resort])
|
||||||
|
self.assertAllEqual(expected_unk_int64,
|
||||||
|
bucketed_values_even[1][1][resort])
|
||||||
|
self.assertAllEqual(expected_vec3_str,
|
||||||
|
bucketed_values_even[1][2][resort])
|
||||||
|
|
||||||
|
# Test the second bucket outputted, the odds starting at 1
|
||||||
|
expected_scalar_int = np.arange(1, 32 * 2 + 1, 2)
|
||||||
|
expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
|
||||||
|
for i in range(0, 32):
|
||||||
|
expected_unk_int64[i, :2*i + 1] = 2*i + 1
|
||||||
|
expected_vec3_str = np.vstack(
|
||||||
|
3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
|
||||||
|
|
||||||
|
# Must resort the output because num_threads > 1 leads to
|
||||||
|
# sometimes-inconsistent insertion order.
|
||||||
|
resort = np.argsort(bucketed_values_odd[1][0])
|
||||||
|
self.assertAllEqual(expected_scalar_int,
|
||||||
|
bucketed_values_odd[1][0][resort])
|
||||||
|
self.assertAllEqual(expected_unk_int64,
|
||||||
|
bucketed_values_odd[1][1][resort])
|
||||||
|
self.assertAllEqual(expected_vec3_str,
|
||||||
|
bucketed_values_odd[1][2][resort])
|
||||||
|
|
||||||
|
def testEvenOddBucketsFilterOutAllOdd(self):
|
||||||
|
which_bucket = (self.scalar_int % 2)
|
||||||
|
keep_input = tf.equal(which_bucket, 0)
|
||||||
|
bucketed_dynamic = tf.contrib.training.bucket(
|
||||||
|
tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
|
||||||
|
which_bucket=which_bucket,
|
||||||
|
num_buckets=2,
|
||||||
|
batch_size=32,
|
||||||
|
num_threads=10,
|
||||||
|
keep_input=keep_input,
|
||||||
|
dynamic_pad=True)
|
||||||
|
# Check shape inference on bucketing outputs
|
||||||
|
self.assertAllEqual(
|
||||||
|
[[32], [32, None], [32, 3]],
|
||||||
|
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
|
||||||
|
with self.test_session() as sess:
|
||||||
|
for v in range(128):
|
||||||
|
self.enqueue_inputs(
|
||||||
|
sess,
|
||||||
|
{self.scalar_int_feed: v,
|
||||||
|
self.unk_int64_feed: v * [v],
|
||||||
|
self.vec3_str_feed: 3 * [str(v)]})
|
||||||
|
self.start_queue_runners(sess)
|
||||||
|
|
||||||
|
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
||||||
|
bucketed_values_even0 = sess.run(bucketed_dynamic)
|
||||||
|
bucketed_values_even1 = sess.run(bucketed_dynamic)
|
||||||
|
|
||||||
|
# Ensure that bucket 1 was completely filtered out
|
||||||
|
self.assertAllEqual(0, bucketed_values_even0[0])
|
||||||
|
self.assertAllEqual(0, bucketed_values_even1[0])
|
||||||
|
|
||||||
|
# Merge their output for sorting and comparison
|
||||||
|
bucketed_values_all_elem0 = np.concatenate(
|
||||||
|
(bucketed_values_even0[1][0],
|
||||||
|
bucketed_values_even1[1][0]))
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
np.arange(0, 128, 2), sorted(bucketed_values_all_elem0))
|
||||||
|
|
||||||
|
|
||||||
|
class BucketBySequenceLengthTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def _testBucketBySequenceLength(self, allow_small_batch):
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
# All inputs must be identical lengths across tuple index.
|
||||||
|
# The input reader will get input_length from the first tuple
|
||||||
|
# entry.
|
||||||
|
data_len = 4
|
||||||
|
target_len = 3
|
||||||
|
input_pairs = [
|
||||||
|
(length,
|
||||||
|
([np.int64(length)] * data_len,
|
||||||
|
[str(length).encode("ascii")] * target_len))
|
||||||
|
for length in (1, 3, 4, 5, 6, 10)]
|
||||||
|
|
||||||
|
lengths = tf.placeholder(tf.int32, ())
|
||||||
|
data = tf.placeholder(tf.int64, (data_len,))
|
||||||
|
targets = tf.placeholder(tf.string, (target_len,))
|
||||||
|
|
||||||
|
batch_size = 8
|
||||||
|
bucket_boundaries = [3, 4, 5, 10]
|
||||||
|
|
||||||
|
# Make capacity very large so we can feed all the inputs in the
|
||||||
|
# main thread without blocking
|
||||||
|
input_queue = tf.FIFOQueue(
|
||||||
|
5000, (tf.int32, tf.int64, tf.string),
|
||||||
|
((), (data_len,), (target_len,)))
|
||||||
|
input_enqueue_op = input_queue.enqueue((lengths, data, targets))
|
||||||
|
lengths_t, data_t, targets_t = input_queue.dequeue()
|
||||||
|
close_input_op = input_queue.close()
|
||||||
|
|
||||||
|
(out_lengths_t, data_and_targets_t) = (
|
||||||
|
tf.contrib.training.bucket_by_sequence_length(
|
||||||
|
input_length=lengths_t,
|
||||||
|
tensors=[data_t, targets_t],
|
||||||
|
batch_size=batch_size,
|
||||||
|
bucket_boundaries=bucket_boundaries,
|
||||||
|
allow_smaller_final_batch=allow_small_batch,
|
||||||
|
num_threads=10))
|
||||||
|
|
||||||
|
expected_batch_size = None if allow_small_batch else batch_size
|
||||||
|
self.assertEqual(out_lengths_t.get_shape().as_list(),
|
||||||
|
[expected_batch_size])
|
||||||
|
self.assertEqual(data_and_targets_t[0].get_shape().as_list(),
|
||||||
|
[expected_batch_size, data_len])
|
||||||
|
self.assertEqual(data_and_targets_t[1].get_shape().as_list(),
|
||||||
|
[expected_batch_size, target_len])
|
||||||
|
|
||||||
|
def _read_test(sess):
|
||||||
|
for _ in range(50):
|
||||||
|
(out_lengths, (data, targets)) = sess.run(
|
||||||
|
(out_lengths_t, data_and_targets_t))
|
||||||
|
if allow_small_batch:
|
||||||
|
self.assertEqual(data_len, data.shape[1])
|
||||||
|
self.assertEqual(target_len, targets.shape[1])
|
||||||
|
self.assertGreaterEqual(batch_size, out_lengths.shape[0])
|
||||||
|
self.assertGreaterEqual(batch_size, data.shape[0])
|
||||||
|
self.assertGreaterEqual(batch_size, targets.shape[0])
|
||||||
|
else:
|
||||||
|
self.assertEqual((batch_size, data_len), data.shape)
|
||||||
|
self.assertEqual((batch_size, target_len), targets.shape)
|
||||||
|
self.assertEqual((batch_size,), out_lengths.shape)
|
||||||
|
for (lr, dr, tr) in zip(out_lengths, data, targets):
|
||||||
|
# Make sure length matches data (here it's the same value)
|
||||||
|
self.assertEqual(dr[0], lr)
|
||||||
|
# Make sure data & targets match
|
||||||
|
self.assertEqual(dr[0], int(tr[0].decode("ascii")))
|
||||||
|
# Make sure for each row, data came from the same bucket.
|
||||||
|
self.assertEqual(_which_bucket(bucket_boundaries, dr[0]),
|
||||||
|
_which_bucket(bucket_boundaries, dr[1]))
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
coord = tf.train.Coordinator()
|
||||||
|
|
||||||
|
# Feed the inputs, then close the input queue.
|
||||||
|
for _ in range(50 * batch_size + 100):
|
||||||
|
which = random.randint(0, len(input_pairs) - 1)
|
||||||
|
length, pair = input_pairs[which]
|
||||||
|
sess.run(input_enqueue_op, feed_dict={
|
||||||
|
lengths: length, data: pair[0], targets: pair[1]})
|
||||||
|
sess.run(close_input_op)
|
||||||
|
|
||||||
|
# Start the queue runners
|
||||||
|
threads = tf.train.start_queue_runners(coord=coord)
|
||||||
|
# Read off the top of the bucket and ensure correctness of output
|
||||||
|
_read_test(sess)
|
||||||
|
coord.request_stop()
|
||||||
|
coord.join(threads)
|
||||||
|
|
||||||
|
def testBucketBySequenceLength(self):
|
||||||
|
self._testBucketBySequenceLength(allow_small_batch=False)
|
||||||
|
|
||||||
|
def testBucketBySequenceLengthAllow(self):
|
||||||
|
self._testBucketBySequenceLength(allow_small_batch=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
@ -171,7 +171,6 @@ cc_library(
|
|||||||
"platform/env.h",
|
"platform/env.h",
|
||||||
"platform/file_system.h",
|
"platform/file_system.h",
|
||||||
"platform/fingerprint.h",
|
"platform/fingerprint.h",
|
||||||
"platform/hexagon/profile_utils/cpu_utils.h",
|
|
||||||
"platform/init_main.h",
|
"platform/init_main.h",
|
||||||
"platform/logging.h",
|
"platform/logging.h",
|
||||||
"platform/macros.h",
|
"platform/macros.h",
|
||||||
@ -179,6 +178,7 @@ cc_library(
|
|||||||
"platform/net.h",
|
"platform/net.h",
|
||||||
"platform/mutex.h",
|
"platform/mutex.h",
|
||||||
"platform/notification.h",
|
"platform/notification.h",
|
||||||
|
"platform/profile_utils/cpu_utils.h",
|
||||||
"platform/protobuf.h", # TODO(josh11b): make internal
|
"platform/protobuf.h", # TODO(josh11b): make internal
|
||||||
"platform/regexp.h",
|
"platform/regexp.h",
|
||||||
"platform/strong_hash.h",
|
"platform/strong_hash.h",
|
||||||
@ -862,8 +862,8 @@ cc_library(
|
|||||||
"lib/**/*.cc",
|
"lib/**/*.cc",
|
||||||
"platform/*.h",
|
"platform/*.h",
|
||||||
"platform/*.cc",
|
"platform/*.cc",
|
||||||
"platform/hexagon/**/*.h",
|
"platform/profile_utils/**/*.h",
|
||||||
"platform/hexagon/**/*.cc",
|
"platform/profile_utils/**/*.cc",
|
||||||
] + tf_additional_lib_srcs(),
|
] + tf_additional_lib_srcs(),
|
||||||
exclude = [
|
exclude = [
|
||||||
"**/*test*",
|
"**/*test*",
|
||||||
@ -891,7 +891,7 @@ cc_library(
|
|||||||
"lib/io/snappy/snappy_inputbuffer.h",
|
"lib/io/snappy/snappy_inputbuffer.h",
|
||||||
"lib/io/snappy/snappy_outputbuffer.h",
|
"lib/io/snappy/snappy_outputbuffer.h",
|
||||||
"lib/io/zlib_compression_options.h",
|
"lib/io/zlib_compression_options.h",
|
||||||
"lib/io/zlib_inputbuffer.h",
|
"lib/io/zlib_inputstream.h",
|
||||||
"lib/io/zlib_outputbuffer.h",
|
"lib/io/zlib_outputbuffer.h",
|
||||||
"lib/jpeg/jpeg_handle.h",
|
"lib/jpeg/jpeg_handle.h",
|
||||||
"lib/png/png_io.h",
|
"lib/png/png_io.h",
|
||||||
@ -1348,11 +1348,11 @@ tf_cc_tests(
|
|||||||
"lib/strings/stringprintf_test.cc",
|
"lib/strings/stringprintf_test.cc",
|
||||||
"lib/wav/wav_io_test.cc",
|
"lib/wav/wav_io_test.cc",
|
||||||
"platform/fingerprint_test.cc",
|
"platform/fingerprint_test.cc",
|
||||||
"platform/hexagon/profile_utils/cpu_utils_test.cc",
|
|
||||||
"platform/integral_types_test.cc",
|
"platform/integral_types_test.cc",
|
||||||
"platform/logging_test.cc",
|
"platform/logging_test.cc",
|
||||||
"platform/net_test.cc",
|
"platform/net_test.cc",
|
||||||
"platform/port_test.cc",
|
"platform/port_test.cc",
|
||||||
|
"platform/profile_utils/cpu_utils_test.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":lib",
|
":lib",
|
||||||
|
@ -157,7 +157,7 @@ bool BFCAllocator::Extend(size_t rounded_bytes) {
|
|||||||
InsertFreeChunkIntoBin(h);
|
InsertFreeChunkIntoBin(h);
|
||||||
|
|
||||||
// Invoke visitors on newly allocated region.
|
// Invoke visitors on newly allocated region.
|
||||||
for (auto visitor : region_visitors_) {
|
for (const auto& visitor : region_visitors_) {
|
||||||
visitor(mem_addr, bytes);
|
visitor(mem_addr, bytes);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -279,7 +279,7 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
|
|||||||
edges_to_remove.push_back(out_edge);
|
edges_to_remove.push_back(out_edge);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
string node_name = n->name();
|
const string& node_name = n->name();
|
||||||
Node* constant_node;
|
Node* constant_node;
|
||||||
auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name),
|
auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name),
|
||||||
"__cf__", UniqueConstantId()),
|
"__cf__", UniqueConstantId()),
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -26,7 +27,9 @@ namespace {
|
|||||||
|
|
||||||
struct RegistrationInfo {
|
struct RegistrationInfo {
|
||||||
RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf)
|
RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf)
|
||||||
: sender_device_type(s), receiver_device_type(r), copy_function(cf) {}
|
: sender_device_type(std::move(s)),
|
||||||
|
receiver_device_type(r),
|
||||||
|
copy_function(cf) {}
|
||||||
DeviceType sender_device_type;
|
DeviceType sender_device_type;
|
||||||
DeviceType receiver_device_type;
|
DeviceType receiver_device_type;
|
||||||
CopyTensor::CopyFunction copy_function;
|
CopyTensor::CopyFunction copy_function;
|
||||||
|
@ -71,9 +71,9 @@ std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
|
|||||||
std::vector<DeviceType> result;
|
std::vector<DeviceType> result;
|
||||||
std::set<string> seen;
|
std::set<string> seen;
|
||||||
for (Device* d : devices_) {
|
for (Device* d : devices_) {
|
||||||
auto t = d->device_type();
|
const auto& t = d->device_type();
|
||||||
if (seen.insert(t).second) {
|
if (seen.insert(t).second) {
|
||||||
result.emplace_back(DeviceType(t));
|
result.emplace_back(t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::sort(result.begin(), result.end(), DeviceTypeComparator);
|
std::sort(result.begin(), result.end(), DeviceTypeComparator);
|
||||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
|
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
|
||||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||||
#include "tensorflow/core/common_runtime/session_factory.h"
|
|
||||||
#include "tensorflow/core/common_runtime/simple_placer.h"
|
#include "tensorflow/core/common_runtime/simple_placer.h"
|
||||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
@ -113,6 +112,77 @@ string GetRendezvousKey(const string& tensor_name,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
class DirectSessionFactory : public SessionFactory {
|
||||||
|
public:
|
||||||
|
DirectSessionFactory() {}
|
||||||
|
|
||||||
|
bool AcceptsOptions(const SessionOptions& options) override {
|
||||||
|
return options.target.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
Session* NewSession(const SessionOptions& options) override {
|
||||||
|
// Must do this before the CPU allocator is created.
|
||||||
|
if (options.config.graph_options().build_cost_model() > 0) {
|
||||||
|
EnableCPUAllocatorFullStats(true);
|
||||||
|
}
|
||||||
|
std::vector<Device*> devices;
|
||||||
|
Status s = DeviceFactory::AddDevices(
|
||||||
|
options, "/job:localhost/replica:0/task:0", &devices);
|
||||||
|
if (!s.ok()) {
|
||||||
|
LOG(ERROR) << s;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectSession* session =
|
||||||
|
new DirectSession(options, new DeviceMgr(devices), this);
|
||||||
|
{
|
||||||
|
mutex_lock l(sessions_lock_);
|
||||||
|
sessions_.push_back(session);
|
||||||
|
}
|
||||||
|
return session;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Reset(const SessionOptions& options,
|
||||||
|
const std::vector<string>& containers) override {
|
||||||
|
std::vector<DirectSession*> sessions_to_reset;
|
||||||
|
{
|
||||||
|
mutex_lock l(sessions_lock_);
|
||||||
|
// We create a copy to ensure that we don't have a deadlock when
|
||||||
|
// session->Close calls the DirectSessionFactory.Deregister, which
|
||||||
|
// acquires sessions_lock_.
|
||||||
|
std::swap(sessions_to_reset, sessions_);
|
||||||
|
}
|
||||||
|
Status s;
|
||||||
|
for (auto session : sessions_to_reset) {
|
||||||
|
s.Update(session->Reset(containers));
|
||||||
|
}
|
||||||
|
// TODO(suharshs): Change the Reset behavior of all SessionFactories so that
|
||||||
|
// it doesn't close the sessions?
|
||||||
|
for (auto session : sessions_to_reset) {
|
||||||
|
s.Update(session->Close());
|
||||||
|
}
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Deregister(const DirectSession* session) {
|
||||||
|
mutex_lock l(sessions_lock_);
|
||||||
|
sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
|
||||||
|
sessions_.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutex sessions_lock_;
|
||||||
|
std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_);
|
||||||
|
};
|
||||||
|
|
||||||
|
class DirectSessionRegistrar {
|
||||||
|
public:
|
||||||
|
DirectSessionRegistrar() {
|
||||||
|
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
static DirectSessionRegistrar registrar;
|
||||||
|
|
||||||
std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
|
std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
|
||||||
|
|
||||||
// NOTE: On Android with a single device, there is never
|
// NOTE: On Android with a single device, there is never
|
||||||
@ -146,10 +216,13 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
|
|||||||
}
|
}
|
||||||
|
|
||||||
DirectSession::DirectSession(const SessionOptions& options,
|
DirectSession::DirectSession(const SessionOptions& options,
|
||||||
const DeviceMgr* device_mgr)
|
const DeviceMgr* device_mgr,
|
||||||
|
DirectSessionFactory* const factory)
|
||||||
: options_(options),
|
: options_(options),
|
||||||
device_mgr_(device_mgr),
|
device_mgr_(device_mgr),
|
||||||
|
factory_(factory),
|
||||||
cancellation_manager_(new CancellationManager()),
|
cancellation_manager_(new CancellationManager()),
|
||||||
|
closed_(false),
|
||||||
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
|
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
|
||||||
if (options_.config.session_inter_op_thread_pool_size() > 0) {
|
if (options_.config.session_inter_op_thread_pool_size() > 0) {
|
||||||
for (int i = 0; i < options_.config.session_inter_op_thread_pool_size();
|
for (int i = 0; i < options_.config.session_inter_op_thread_pool_size();
|
||||||
@ -194,6 +267,7 @@ DirectSession::DirectSession(const SessionOptions& options,
|
|||||||
}
|
}
|
||||||
|
|
||||||
DirectSession::~DirectSession() {
|
DirectSession::~DirectSession() {
|
||||||
|
if (!closed_) Close();
|
||||||
for (auto& it : partial_runs_) {
|
for (auto& it : partial_runs_) {
|
||||||
it.second.reset(nullptr);
|
it.second.reset(nullptr);
|
||||||
}
|
}
|
||||||
@ -237,6 +311,7 @@ Status DirectSession::Create(const GraphDef& graph) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status DirectSession::Extend(const GraphDef& graph) {
|
Status DirectSession::Extend(const GraphDef& graph) {
|
||||||
|
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||||
mutex_lock l(graph_def_lock_);
|
mutex_lock l(graph_def_lock_);
|
||||||
return ExtendLocked(graph);
|
return ExtendLocked(graph);
|
||||||
}
|
}
|
||||||
@ -267,6 +342,7 @@ Status DirectSession::Run(const RunOptions& run_options,
|
|||||||
const std::vector<string>& target_nodes,
|
const std::vector<string>& target_nodes,
|
||||||
std::vector<Tensor>* outputs,
|
std::vector<Tensor>* outputs,
|
||||||
RunMetadata* run_metadata) {
|
RunMetadata* run_metadata) {
|
||||||
|
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||||
direct_session_runs->GetCell()->IncrementBy(1);
|
direct_session_runs->GetCell()->IncrementBy(1);
|
||||||
{
|
{
|
||||||
mutex_lock l(graph_def_lock_);
|
mutex_lock l(graph_def_lock_);
|
||||||
@ -412,6 +488,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
|
|||||||
const std::vector<string>& output_names,
|
const std::vector<string>& output_names,
|
||||||
const std::vector<string>& target_nodes,
|
const std::vector<string>& target_nodes,
|
||||||
string* handle) {
|
string* handle) {
|
||||||
|
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||||
{
|
{
|
||||||
mutex_lock l(graph_def_lock_);
|
mutex_lock l(graph_def_lock_);
|
||||||
if (!graph_created_) {
|
if (!graph_created_) {
|
||||||
@ -487,6 +564,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
|
|||||||
Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
|
Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
|
||||||
const std::vector<string>& output_names,
|
const std::vector<string>& output_names,
|
||||||
std::vector<Tensor>* outputs) {
|
std::vector<Tensor>* outputs) {
|
||||||
|
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||||
std::vector<string> parts = str_util::Split(handle, ';');
|
std::vector<string> parts = str_util::Split(handle, ';');
|
||||||
const string& key = parts[0];
|
const string& key = parts[0];
|
||||||
// Get the executors for this partial run.
|
// Get the executors for this partial run.
|
||||||
@ -1002,8 +1080,20 @@ Status DirectSession::CreateGraphs(
|
|||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
::tensorflow::Status DirectSession::Reset(
|
||||||
|
const std::vector<string>& containers) {
|
||||||
|
device_mgr_->ClearContainers(containers);
|
||||||
|
return ::tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
::tensorflow::Status DirectSession::Close() {
|
::tensorflow::Status DirectSession::Close() {
|
||||||
cancellation_manager_->StartCancel();
|
cancellation_manager_->StartCancel();
|
||||||
|
{
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (closed_) return ::tensorflow::Status::OK();
|
||||||
|
closed_ = true;
|
||||||
|
}
|
||||||
|
if (factory_ != nullptr) factory_->Deregister(this);
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1051,37 +1141,4 @@ void DirectSession::WaitForNotification(RunState* run_state,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class DirectSessionFactory : public SessionFactory {
|
|
||||||
public:
|
|
||||||
DirectSessionFactory() {}
|
|
||||||
|
|
||||||
bool AcceptsOptions(const SessionOptions& options) override {
|
|
||||||
return options.target.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
Session* NewSession(const SessionOptions& options) override {
|
|
||||||
// Must do this before the CPU allocator is created.
|
|
||||||
if (options.config.graph_options().build_cost_model() > 0) {
|
|
||||||
EnableCPUAllocatorFullStats(true);
|
|
||||||
}
|
|
||||||
std::vector<Device*> devices;
|
|
||||||
Status s = DeviceFactory::AddDevices(
|
|
||||||
options, "/job:localhost/replica:0/task:0", &devices);
|
|
||||||
if (!s.ok()) {
|
|
||||||
LOG(ERROR) << s;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new DirectSession(options, new DeviceMgr(devices));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class DirectSessionRegistrar {
|
|
||||||
public:
|
|
||||||
DirectSessionRegistrar() {
|
|
||||||
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
static DirectSessionRegistrar registrar;
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device_set.h"
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
#include "tensorflow/core/common_runtime/executor.h"
|
#include "tensorflow/core/common_runtime/executor.h"
|
||||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/session_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
|
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
|
||||||
#include "tensorflow/core/debug/debug_graph_utils.h"
|
#include "tensorflow/core/debug/debug_graph_utils.h"
|
||||||
#include "tensorflow/core/framework/cancellation.h"
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
@ -47,11 +48,18 @@ namespace tensorflow {
|
|||||||
class CostModel;
|
class CostModel;
|
||||||
class DebugGateway;
|
class DebugGateway;
|
||||||
class Device;
|
class Device;
|
||||||
|
class DirectSessionFactory;
|
||||||
|
|
||||||
class DirectSession : public Session {
|
class DirectSession : public Session {
|
||||||
public:
|
public:
|
||||||
|
typedef std::function<void(Session*)> CloseCallback;
|
||||||
|
|
||||||
// Takes ownership of 'device_mgr'.
|
// Takes ownership of 'device_mgr'.
|
||||||
DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr);
|
// 'factory' is used to unregister the DirectSession with 'factory' when its
|
||||||
|
// closed. This ensures that Reset requests from the 'factory' don't get sent
|
||||||
|
// to sessions that are already closed.
|
||||||
|
DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
|
||||||
|
DirectSessionFactory* factory);
|
||||||
~DirectSession() override;
|
~DirectSession() override;
|
||||||
|
|
||||||
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
|
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
|
||||||
@ -83,6 +91,10 @@ class DirectSession : public Session {
|
|||||||
const std::vector<string>& output_names,
|
const std::vector<string>& output_names,
|
||||||
std::vector<Tensor>* outputs) override;
|
std::vector<Tensor>* outputs) override;
|
||||||
|
|
||||||
|
// Reset clears 'containers' from the device_mgr of the DirectSession.
|
||||||
|
// If 'containers' is empty, then Reset clears the default container.
|
||||||
|
::tensorflow::Status Reset(const std::vector<string>& containers);
|
||||||
|
|
||||||
::tensorflow::Status Close() override;
|
::tensorflow::Status Close() override;
|
||||||
|
|
||||||
void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
|
void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
|
||||||
@ -198,6 +210,12 @@ class DirectSession : public Session {
|
|||||||
// operation_timeout_in_ms is greater than 0.
|
// operation_timeout_in_ms is greater than 0.
|
||||||
void WaitForNotification(RunState* run_state, int64 timeout_in_ms);
|
void WaitForNotification(RunState* run_state, int64 timeout_in_ms);
|
||||||
|
|
||||||
|
::tensorflow::Status CheckNotClosed() {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (closed_) return errors::Cancelled("Session has been closed.");
|
||||||
|
return ::tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
const SessionOptions options_;
|
const SessionOptions options_;
|
||||||
|
|
||||||
// Device structures.
|
// Device structures.
|
||||||
@ -232,10 +250,12 @@ class DirectSession : public Session {
|
|||||||
// This holds all the tensors that are currently alive in the session.
|
// This holds all the tensors that are currently alive in the session.
|
||||||
SessionState session_state_;
|
SessionState session_state_;
|
||||||
|
|
||||||
|
DirectSessionFactory* const factory_; // not owned
|
||||||
CancellationManager* cancellation_manager_;
|
CancellationManager* cancellation_manager_;
|
||||||
|
|
||||||
// Saves and restores device placements for stateful nodes.
|
// Saves and restores device placements for stateful nodes.
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
|
|
||||||
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
|
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
|
||||||
// is true, such as "params" and "queue" nodes. Once placed these
|
// is true, such as "params" and "queue" nodes. Once placed these
|
||||||
// nodes can not be moved to a different device. Maps node names to
|
// nodes can not be moved to a different device. Maps node names to
|
||||||
@ -251,6 +271,9 @@ class DirectSession : public Session {
|
|||||||
// library; it copies and modifies the function library.
|
// library; it copies and modifies the function library.
|
||||||
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
||||||
|
|
||||||
|
// true if the Session has been Closed.
|
||||||
|
bool closed_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
// For generating unique names.
|
// For generating unique names.
|
||||||
int64 name_counter_ GUARDED_BY(mu_) = 0;
|
int64 name_counter_ GUARDED_BY(mu_) = 0;
|
||||||
|
|
||||||
|
@ -397,6 +397,14 @@ TEST(DirectSessionTest, MultipleFeedTest) {
|
|||||||
ASSERT_EQ(2, outputs.size());
|
ASSERT_EQ(2, outputs.size());
|
||||||
ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
|
ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
|
||||||
ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
|
ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
|
||||||
|
|
||||||
|
// Feed [first_const, first_const]
|
||||||
|
s = session->Run(
|
||||||
|
{{first_const->name(), value_11}, {first_const->name(), value_22}},
|
||||||
|
{first_identity->name() + ":0", second_identity->name() + ":0"}, {},
|
||||||
|
&outputs);
|
||||||
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
|
EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once"));
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_OP("Darth")
|
REGISTER_OP("Darth")
|
||||||
@ -970,5 +978,129 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(DirectSessionTest, TestDirectSessionRunClose) {
|
||||||
|
// Construct a graph with a variable and a single assign.
|
||||||
|
Graph g(OpRegistry::Global());
|
||||||
|
Tensor t(DT_FLOAT, TensorShape({}));
|
||||||
|
t.scalar<float>()() = {1.2};
|
||||||
|
Node* var_val = test::graph::Constant(&g, t);
|
||||||
|
Node* var = test::graph::Var(&g, DT_FLOAT, {});
|
||||||
|
Node* var_assign = test::graph::Assign(&g, var, var_val);
|
||||||
|
GraphDef def;
|
||||||
|
test::graph::ToGraphDef(&g, &def);
|
||||||
|
|
||||||
|
SessionOptions options;
|
||||||
|
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||||
|
std::unique_ptr<Session> session(NewSession(options));
|
||||||
|
ASSERT_TRUE(session != nullptr);
|
||||||
|
TF_ASSERT_OK(session->Create(def));
|
||||||
|
|
||||||
|
// Assign a value to the var.
|
||||||
|
TF_ASSERT_OK(session->Run({} /* inputs */, {},
|
||||||
|
{var_assign->name()} /* target_nodes */, nullptr));
|
||||||
|
|
||||||
|
// Run a read on the variable to ensure that it works.
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
TF_ASSERT_OK(session->Run(
|
||||||
|
{} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
|
||||||
|
EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
|
||||||
|
outputs.clear();
|
||||||
|
|
||||||
|
// Close the session.
|
||||||
|
session->Close();
|
||||||
|
|
||||||
|
// Run the read on the variable to get an error.
|
||||||
|
Status s = session->Run({} /* inputs */, {},
|
||||||
|
{var_assign->name()} /* target_nodes */, nullptr);
|
||||||
|
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
|
||||||
|
GraphDef def;
|
||||||
|
Graph g(OpRegistry::Global());
|
||||||
|
|
||||||
|
Tensor first_value(DT_FLOAT, TensorShape({}));
|
||||||
|
first_value.scalar<float>()() = 1.0;
|
||||||
|
Node* first_const = test::graph::Constant(&g, first_value);
|
||||||
|
Node* first_identity = test::graph::Identity(&g, first_const);
|
||||||
|
|
||||||
|
Tensor second_value(DT_FLOAT, TensorShape({}));
|
||||||
|
second_value.scalar<float>()() = 2.0;
|
||||||
|
Node* second_const = test::graph::Constant(&g, second_value);
|
||||||
|
Node* second_identity = test::graph::Identity(&g, second_const);
|
||||||
|
|
||||||
|
Node* third = test::graph::Add(&g, first_identity, second_identity);
|
||||||
|
Node* third_identity = test::graph::Identity(&g, third);
|
||||||
|
|
||||||
|
test::graph::ToGraphDef(&g, &def);
|
||||||
|
|
||||||
|
std::unique_ptr<Session> session(CreateSession());
|
||||||
|
ASSERT_TRUE(session != nullptr);
|
||||||
|
TF_ASSERT_OK(session->Create(def));
|
||||||
|
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
|
||||||
|
string handle;
|
||||||
|
Status s = session->PRunSetup(
|
||||||
|
{first_const->name(), second_const->name()},
|
||||||
|
{first_identity->name() + ":0", second_identity->name() + ":0",
|
||||||
|
third_identity->name() + ":0"},
|
||||||
|
{}, &handle);
|
||||||
|
TF_ASSERT_OK(s);
|
||||||
|
|
||||||
|
Tensor value_11(DT_FLOAT, TensorShape({}));
|
||||||
|
value_11.scalar<float>()() = 11.0;
|
||||||
|
Tensor value_22(DT_FLOAT, TensorShape({}));
|
||||||
|
value_22.scalar<float>()() = 22.0;
|
||||||
|
|
||||||
|
// Close the session.
|
||||||
|
session->Close();
|
||||||
|
|
||||||
|
// Feed first_const, fetch first_identity
|
||||||
|
s = session->PRun(handle, {{first_const->name(), value_11}},
|
||||||
|
{first_identity->name() + ":0"}, &outputs);
|
||||||
|
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DirectSessionTest, TestDirectSessionReset) {
|
||||||
|
// Construct a graph with a variable and a single assign.
|
||||||
|
Graph g(OpRegistry::Global());
|
||||||
|
Tensor t(DT_FLOAT, TensorShape({}));
|
||||||
|
t.scalar<float>()() = {1.2};
|
||||||
|
Node* var_val = test::graph::Constant(&g, t);
|
||||||
|
Node* var = test::graph::Var(&g, DT_FLOAT, {});
|
||||||
|
Node* var_assign = test::graph::Assign(&g, var, var_val);
|
||||||
|
GraphDef def;
|
||||||
|
test::graph::ToGraphDef(&g, &def);
|
||||||
|
|
||||||
|
SessionOptions options;
|
||||||
|
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||||
|
std::unique_ptr<Session> session(NewSession(options));
|
||||||
|
ASSERT_TRUE(session != nullptr);
|
||||||
|
TF_ASSERT_OK(session->Create(def));
|
||||||
|
|
||||||
|
// Assign a value to the var.
|
||||||
|
TF_ASSERT_OK(session->Run({} /* inputs */, {},
|
||||||
|
{var_assign->name()} /* target_nodes */, nullptr));
|
||||||
|
|
||||||
|
// Run a read on the variable to ensure that it works.
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
TF_ASSERT_OK(session->Run(
|
||||||
|
{} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
|
||||||
|
EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
|
||||||
|
outputs.clear();
|
||||||
|
|
||||||
|
// Reset the containers.
|
||||||
|
Reset(options, {});
|
||||||
|
|
||||||
|
// Run the read on the variable to get an error.
|
||||||
|
// TODO(suharshs): This test only works because we close the Session in Reset.
|
||||||
|
// If we change the behavior of Reset to not close the Session, this test will
|
||||||
|
// fail, since the Variable buffer is cached by var.
|
||||||
|
Status s = session->Run({} /* inputs */, {},
|
||||||
|
{var_assign->name()} /* target_nodes */, nullptr);
|
||||||
|
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -144,7 +144,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
|
|
||||||
void Init(const std::vector<FunctionDef>& flib) {
|
void Init(const std::vector<FunctionDef>& flib) {
|
||||||
FunctionDefLibrary proto;
|
FunctionDefLibrary proto;
|
||||||
for (auto fdef : flib) *(proto.add_function()) = fdef;
|
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
||||||
delete lib_def_;
|
delete lib_def_;
|
||||||
lib_def_ = new FunctionLibraryDefinition(OpRegistry::Global(), proto);
|
lib_def_ = new FunctionLibraryDefinition(OpRegistry::Global(), proto);
|
||||||
delete lib_;
|
delete lib_;
|
||||||
|
@ -95,7 +95,7 @@ void EventMgr::ThenDeleteTensors(perftools::gputools::Stream* stream,
|
|||||||
FlushAccumulatedTensors();
|
FlushAccumulatedTensors();
|
||||||
}
|
}
|
||||||
accumulated_stream_ = stream;
|
accumulated_stream_ = stream;
|
||||||
for (auto t : tensors) {
|
for (const auto& t : tensors) {
|
||||||
// accumulated_tensors_ takes over ownership of the reference to "t"
|
// accumulated_tensors_ takes over ownership of the reference to "t"
|
||||||
accumulated_tensors_->push_back(t);
|
accumulated_tensors_->push_back(t);
|
||||||
accumulated_tensor_bytes_ += t.TotalBytes();
|
accumulated_tensor_bytes_ += t.TotalBytes();
|
||||||
|
@ -129,7 +129,7 @@ TEST_F(GpuStreamUtilTest, StreamOverrides) {
|
|||||||
// Nodes should be assigned to streams by op type.
|
// Nodes should be assigned to streams by op type.
|
||||||
for (const auto& it : node_to_stream_id) {
|
for (const auto& it : node_to_stream_id) {
|
||||||
Node* n = g.FindNodeId(it.first);
|
Node* n = g.FindNodeId(it.first);
|
||||||
const string op = n->type_string();
|
const string& op = n->type_string();
|
||||||
const int stream = it.second;
|
const int stream = it.second;
|
||||||
if (op == "Const") {
|
if (op == "Const") {
|
||||||
EXPECT_EQ(stream, 90);
|
EXPECT_EQ(stream, 90);
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <sys/mman.h> // for munmap
|
#include <sys/mman.h> // for munmap
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "tensorflow/core/lib/strings/numbers.h"
|
#include "tensorflow/core/lib/strings/numbers.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -31,7 +32,7 @@ namespace tensorflow {
|
|||||||
PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
|
PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
|
||||||
SubAllocator* allocator,
|
SubAllocator* allocator,
|
||||||
RoundUpInterface* size_rounder, string name)
|
RoundUpInterface* size_rounder, string name)
|
||||||
: name_(name),
|
: name_(std::move(name)),
|
||||||
has_size_limit_(pool_size_limit > 0),
|
has_size_limit_(pool_size_limit > 0),
|
||||||
auto_resize_(auto_resize),
|
auto_resize_(auto_resize),
|
||||||
pool_size_limit_(pool_size_limit),
|
pool_size_limit_(pool_size_limit),
|
||||||
@ -125,7 +126,7 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
|
|||||||
return PrepareChunk(r, alignment, num_bytes);
|
return PrepareChunk(r, alignment, num_bytes);
|
||||||
} else {
|
} else {
|
||||||
void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
|
void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
|
||||||
for (auto v : alloc_visitors_) {
|
for (const auto& v : alloc_visitors_) {
|
||||||
v(ptr, num_bytes);
|
v(ptr, num_bytes);
|
||||||
}
|
}
|
||||||
return PrepareChunk(ptr, alignment, num_bytes);
|
return PrepareChunk(ptr, alignment, num_bytes);
|
||||||
@ -137,7 +138,7 @@ void PoolAllocator::DeallocateRaw(void* ptr) {
|
|||||||
ChunkPrefix* cp = FindPrefix(ptr);
|
ChunkPrefix* cp = FindPrefix(ptr);
|
||||||
CHECK_LE((void*)cp, (void*)ptr);
|
CHECK_LE((void*)cp, (void*)ptr);
|
||||||
if (!has_size_limit_ && !auto_resize_) {
|
if (!has_size_limit_ && !auto_resize_) {
|
||||||
for (auto v : free_visitors_) {
|
for (const auto& v : free_visitors_) {
|
||||||
v(cp, cp->num_bytes);
|
v(cp, cp->num_bytes);
|
||||||
}
|
}
|
||||||
allocator_->Free(cp, cp->num_bytes);
|
allocator_->Free(cp, cp->num_bytes);
|
||||||
@ -160,7 +161,7 @@ void PoolAllocator::Clear() {
|
|||||||
mutex_lock lock(mutex_);
|
mutex_lock lock(mutex_);
|
||||||
for (auto iter : pool_) {
|
for (auto iter : pool_) {
|
||||||
PtrRecord* pr = iter.second;
|
PtrRecord* pr = iter.second;
|
||||||
for (auto v : free_visitors_) {
|
for (const auto& v : free_visitors_) {
|
||||||
v(pr->ptr, pr->num_bytes);
|
v(pr->ptr, pr->num_bytes);
|
||||||
}
|
}
|
||||||
allocator_->Free(pr->ptr, pr->num_bytes);
|
allocator_->Free(pr->ptr, pr->num_bytes);
|
||||||
@ -217,7 +218,7 @@ void PoolAllocator::EvictOne() {
|
|||||||
DCHECK(iter != pool_.end());
|
DCHECK(iter != pool_.end());
|
||||||
}
|
}
|
||||||
pool_.erase(iter);
|
pool_.erase(iter);
|
||||||
for (auto v : free_visitors_) {
|
for (const auto& v : free_visitors_) {
|
||||||
v(prec->ptr, prec->num_bytes);
|
v(prec->ptr, prec->num_bytes);
|
||||||
}
|
}
|
||||||
allocator_->Free(prec->ptr, prec->num_bytes);
|
allocator_->Free(prec->ptr, prec->num_bytes);
|
||||||
|
@ -181,12 +181,25 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
|
|||||||
// different numa_nodes. For now, just one.
|
// different numa_nodes. For now, just one.
|
||||||
numa_node = 0;
|
numa_node = 0;
|
||||||
mutex_lock lock(mu_);
|
mutex_lock lock(mu_);
|
||||||
|
|
||||||
|
// Find the first valid StreamExecutor to request CUDA host memory
|
||||||
|
// through, since any will work.
|
||||||
|
//
|
||||||
|
// This search isn't super clean, and it would be nice to use a
|
||||||
|
// better source of information about which executor to use. For
|
||||||
|
// example, process_state could maybe save the first stream executor
|
||||||
|
// it knows is valid.
|
||||||
|
gpu::StreamExecutor* se = nullptr;
|
||||||
|
for (size_t i = 0; i < gpu_allocators_.size(); ++i) {
|
||||||
|
if (gpu_allocators_[i] != nullptr) {
|
||||||
|
se = GPUMachineManager()->ExecutorForDevice(i).ValueOrDie();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_NE(nullptr, se);
|
||||||
|
|
||||||
while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
|
while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
|
||||||
// CUDAHost alloc the same across all gpus, so just get the
|
|
||||||
// executor for the first device.
|
|
||||||
gpu::Platform* gpu_platform = GPUMachineManager();
|
|
||||||
gpu::StreamExecutor* se = gpu_platform->ExecutorForDevice(0).ValueOrDie();
|
|
||||||
CHECK(se);
|
|
||||||
Allocator* allocator = nullptr;
|
Allocator* allocator = nullptr;
|
||||||
static constexpr bool kCudaHostMemoryUseBFC = true;
|
static constexpr bool kCudaHostMemoryUseBFC = true;
|
||||||
if (kCudaHostMemoryUseBFC) {
|
if (kCudaHostMemoryUseBFC) {
|
||||||
|
@ -44,6 +44,7 @@ SimpleGraphExecutionState::SimpleGraphExecutionState(
|
|||||||
const SimpleGraphExecutionStateOptions& options)
|
const SimpleGraphExecutionStateOptions& options)
|
||||||
: device_set_(options.device_set),
|
: device_set_(options.device_set),
|
||||||
session_options_(options.session_options),
|
session_options_(options.session_options),
|
||||||
|
costs_(true /*is_global*/),
|
||||||
flib_def_(
|
flib_def_(
|
||||||
new FunctionLibraryDefinition(OpRegistry::Global(), func_def_lib)),
|
new FunctionLibraryDefinition(OpRegistry::Global(), func_def_lib)),
|
||||||
graph_(nullptr) {
|
graph_(nullptr) {
|
||||||
@ -53,6 +54,7 @@ SimpleGraphExecutionState::SimpleGraphExecutionState(
|
|||||||
|
|
||||||
SimpleGraphExecutionState::~SimpleGraphExecutionState() {
|
SimpleGraphExecutionState::~SimpleGraphExecutionState() {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
node_name_to_cost_id_map_.clear();
|
||||||
delete graph_;
|
delete graph_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,6 +180,10 @@ Status SimpleGraphExecutionState::InitBaseGraph(
|
|||||||
GraphConstructorOptions opts;
|
GraphConstructorOptions opts;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
ConvertGraphDefToGraph(opts, original_graph_def_, new_graph.get()));
|
ConvertGraphDefToGraph(opts, original_graph_def_, new_graph.get()));
|
||||||
|
for (const Node* n : new_graph->nodes()) {
|
||||||
|
VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
|
||||||
|
node_name_to_cost_id_map_[n->name()] = n->cost_id();
|
||||||
|
}
|
||||||
if (session_options_ &&
|
if (session_options_ &&
|
||||||
session_options_->config.graph_options().place_pruned_graph()) {
|
session_options_->config.graph_options().place_pruned_graph()) {
|
||||||
// Rewrite the graph before placement.
|
// Rewrite the graph before placement.
|
||||||
@ -189,10 +195,15 @@ Status SimpleGraphExecutionState::InitBaseGraph(
|
|||||||
// Save stateful placements before placing.
|
// Save stateful placements before placing.
|
||||||
RestoreStatefulNodes(new_graph.get());
|
RestoreStatefulNodes(new_graph.get());
|
||||||
|
|
||||||
|
CostModel costs(true /*is_global*/);
|
||||||
|
costs_.InitFromGraph(*new_graph.get());
|
||||||
|
costs.MergeFromGlobal(costs_);
|
||||||
|
|
||||||
GraphOptimizationPassOptions optimization_options;
|
GraphOptimizationPassOptions optimization_options;
|
||||||
optimization_options.session_options = session_options_;
|
optimization_options.session_options = session_options_;
|
||||||
optimization_options.graph = &new_graph;
|
optimization_options.graph = &new_graph;
|
||||||
optimization_options.flib_def = flib_def_.get();
|
optimization_options.flib_def = flib_def_.get();
|
||||||
|
optimization_options.cost_model = &costs;
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
||||||
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
|
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
|
||||||
@ -209,6 +220,31 @@ Status SimpleGraphExecutionState::InitBaseGraph(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SimpleGraphExecutionState::UpdateCostsFromStats(const StepStats& ss) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
costs_.MergeFromStats(node_name_to_cost_id_map_, ss);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SimpleGraphExecutionState::MergeCostsFromGlobal(CostModel* costs) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
costs->MergeFromGlobal(costs_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name,
|
||||||
|
NodeDef* out) {
|
||||||
|
NodeNameToCostIdMap::const_iterator iter =
|
||||||
|
node_name_to_cost_id_map_.find(name);
|
||||||
|
if (iter != node_name_to_cost_id_map_.end()) {
|
||||||
|
mutex_lock l(mu_); // could use reader lock
|
||||||
|
const Node* node = graph_->FindNodeId(iter->second);
|
||||||
|
if (node) {
|
||||||
|
*out = node->def();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors::NotFound("Node name: ", name);
|
||||||
|
}
|
||||||
|
|
||||||
Status SimpleGraphExecutionState::BuildGraph(
|
Status SimpleGraphExecutionState::BuildGraph(
|
||||||
const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
|
const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
|
||||||
VLOG(1) << "BuildGraph";
|
VLOG(1) << "BuildGraph";
|
||||||
@ -234,10 +270,14 @@ Status SimpleGraphExecutionState::BuildGraph(
|
|||||||
std::unique_ptr<FunctionLibraryDefinition> flib(
|
std::unique_ptr<FunctionLibraryDefinition> flib(
|
||||||
new FunctionLibraryDefinition(*flib_def_));
|
new FunctionLibraryDefinition(*flib_def_));
|
||||||
|
|
||||||
|
// TODO(andydavis): Clarify optimization pass requirements around CostModel.
|
||||||
|
CostModel costs(true /*is_global*/);
|
||||||
|
costs.MergeFromGlobal(costs_);
|
||||||
GraphOptimizationPassOptions optimization_options;
|
GraphOptimizationPassOptions optimization_options;
|
||||||
optimization_options.session_options = session_options_;
|
optimization_options.session_options = session_options_;
|
||||||
optimization_options.graph = &ng;
|
optimization_options.graph = &ng;
|
||||||
optimization_options.flib_def = flib.get();
|
optimization_options.flib_def = flib.get();
|
||||||
|
optimization_options.cost_model = &costs;
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
||||||
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
|
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
|
||||||
|
@ -119,6 +119,22 @@ class SimpleGraphExecutionState {
|
|||||||
// execution, e.g. a send, recv or feed node.
|
// execution, e.g. a send, recv or feed node.
|
||||||
Status GlobalNodeDefByName(const string& name, NodeDef* out);
|
Status GlobalNodeDefByName(const string& name, NodeDef* out);
|
||||||
|
|
||||||
|
// Sums execution statistics in "ss" into the CostModel.
|
||||||
|
void UpdateCostsFromStats(const StepStats& ss);
|
||||||
|
|
||||||
|
Microseconds TimeEstimate(const Node* n) {
|
||||||
|
mutex_lock l(mu_); // could use reader lock
|
||||||
|
return costs_.TimeEstimate(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
Bytes SizeEstimate(const Node* n, int output_slot) {
|
||||||
|
mutex_lock l(mu_); // could use reader lock
|
||||||
|
return costs_.SizeEstimate(n, output_slot);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge the cost model maintained by this graph_execution_state to 'costs'.
|
||||||
|
void MergeCostsFromGlobal(CostModel* costs);
|
||||||
|
|
||||||
// The graph returned by BuildGraph may contain only the pruned
|
// The graph returned by BuildGraph may contain only the pruned
|
||||||
// graph, whereas some clients may want access to the full graph.
|
// graph, whereas some clients may want access to the full graph.
|
||||||
const Graph* full_graph() {
|
const Graph* full_graph() {
|
||||||
@ -162,6 +178,11 @@ class SimpleGraphExecutionState {
|
|||||||
const DeviceSet* device_set_; // Not owned
|
const DeviceSet* device_set_; // Not owned
|
||||||
const SessionOptions* session_options_; // Not owned
|
const SessionOptions* session_options_; // Not owned
|
||||||
|
|
||||||
|
CostModel costs_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// Map from name to Node for the full graph in placed_.
|
||||||
|
NodeNameToCostIdMap node_name_to_cost_id_map_;
|
||||||
|
|
||||||
// 'flib_def_' is initialized from the initial graph def's library,
|
// 'flib_def_' is initialized from the initial graph def's library,
|
||||||
// and may be updated by a graph optimization pass.
|
// and may be updated by a graph optimization pass.
|
||||||
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
||||||
|
@ -42,7 +42,7 @@ std::vector<Device*> FilterSupportedDevices(
|
|||||||
const std::vector<Device*>& devices,
|
const std::vector<Device*>& devices,
|
||||||
const DeviceTypeVector& supported_device_types) {
|
const DeviceTypeVector& supported_device_types) {
|
||||||
std::vector<Device*> filtered_devices;
|
std::vector<Device*> filtered_devices;
|
||||||
for (DeviceType d : supported_device_types) {
|
for (const DeviceType& d : supported_device_types) {
|
||||||
for (Device* device : devices) {
|
for (Device* device : devices) {
|
||||||
if (DeviceType(device->attributes().device_type()) == d) {
|
if (DeviceType(device->attributes().device_type()) == d) {
|
||||||
filtered_devices.emplace_back(device);
|
filtered_devices.emplace_back(device);
|
||||||
@ -238,11 +238,15 @@ class ColocationGraph {
|
|||||||
// members_[old_root].supported_device_types.
|
// members_[old_root].supported_device_types.
|
||||||
MergeSupportedDevices(&members_[new_root].supported_device_types,
|
MergeSupportedDevices(&members_[new_root].supported_device_types,
|
||||||
members_[old_root].supported_device_types);
|
members_[old_root].supported_device_types);
|
||||||
if (members_[x_root].supported_device_types.size() == 0) {
|
if (members_[new_root].supported_device_types.size() == 0) {
|
||||||
|
string debug_info;
|
||||||
|
AddDebugInfo(x_root, &debug_info);
|
||||||
|
AddDebugInfo(y_root, &debug_info);
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Cannot colocate nodes '", x.name(), "' and '", y.name(),
|
"Cannot colocate nodes '", x.name(), "' and '", y.name(),
|
||||||
"' because no device type supports both of those nodes and the "
|
"' because no device type supports both of those nodes and the "
|
||||||
"other nodes colocated with them");
|
"other nodes colocated with them.",
|
||||||
|
debug_info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -495,7 +499,7 @@ class ColocationGraph {
|
|||||||
"' does not match any device");
|
"' does not match any device");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (DeviceType d : member->supported_device_types) {
|
for (const DeviceType& d : member->supported_device_types) {
|
||||||
if (DeviceType(assigned_device->attributes().device_type()) == d) {
|
if (DeviceType(assigned_device->attributes().device_type()) == d) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -545,9 +549,9 @@ class ColocationGraph {
|
|||||||
target->clear();
|
target->clear();
|
||||||
|
|
||||||
// Iterate in priority order.
|
// Iterate in priority order.
|
||||||
for (DeviceType device_type : temp) {
|
for (const DeviceType& device_type : temp) {
|
||||||
bool found = false;
|
bool found = false;
|
||||||
for (DeviceType other_device_type : other) {
|
for (const DeviceType& other_device_type : other) {
|
||||||
if (device_type == other_device_type) {
|
if (device_type == other_device_type) {
|
||||||
found = true;
|
found = true;
|
||||||
break;
|
break;
|
||||||
|
@ -689,8 +689,9 @@ TEST_F(SimplePlacerTest,
|
|||||||
Status s = Place(&g);
|
Status s = Place(&g);
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
StringPiece(s.error_message())
|
StringPiece(s.error_message())
|
||||||
.contains("Cannot assign a device to node 'var3': Node had no "
|
.contains("Cannot colocate nodes 'var3' and 'assign3' because no "
|
||||||
"OpKernel registered"));
|
"device type supports both of those nodes and the other "
|
||||||
|
"nodes colocated with them."));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
|
TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
|
||||||
|
@ -54,9 +54,9 @@ namespace tensorflow {
|
|||||||
namespace {
|
namespace {
|
||||||
// A little bit of per-step state.
|
// A little bit of per-step state.
|
||||||
struct PerStepState {
|
struct PerStepState {
|
||||||
|
bool collect_timeline;
|
||||||
Microseconds start_micros = Microseconds(0);
|
Microseconds start_micros = Microseconds(0);
|
||||||
Microseconds end_micros = Microseconds(0);
|
Microseconds end_micros = Microseconds(0);
|
||||||
std::vector<StepStats> step_stats; // per partition
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// A session encapsulates a graph computation (resource allocation,
|
// A session encapsulates a graph computation (resource allocation,
|
||||||
@ -522,6 +522,10 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
|||||||
|
|
||||||
// Prepares a number of calls to workers. One call per partition.
|
// Prepares a number of calls to workers. One call per partition.
|
||||||
ExecutorOpts exec_opts;
|
ExecutorOpts exec_opts;
|
||||||
|
if (pss->collect_timeline) {
|
||||||
|
exec_opts.set_record_timeline(true);
|
||||||
|
}
|
||||||
|
|
||||||
const int num = partitions_.size();
|
const int num = partitions_.size();
|
||||||
RunManyGraphs calls(num);
|
RunManyGraphs calls(num);
|
||||||
|
|
||||||
@ -597,8 +601,9 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (calls.get(i)->resp.has_step_stats()) {
|
if (pss->collect_timeline && calls.get(i)->resp.has_step_stats()) {
|
||||||
pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats());
|
resp->mutable_metadata()->mutable_step_stats()->MergeFrom(
|
||||||
|
calls.get(i)->resp.step_stats());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -953,6 +958,8 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
|||||||
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
||||||
TRACEPRINTF("stepid %llu", step_id);
|
TRACEPRINTF("stepid %llu", step_id);
|
||||||
|
|
||||||
|
pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
|
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
|
||||||
execution_state_.get(), &pss, opts,
|
execution_state_.get(), &pss, opts,
|
||||||
*req, resp, cancellation_manager_));
|
*req, resp, cancellation_manager_));
|
||||||
|
@ -162,6 +162,8 @@ Status GrpcSession::Run(const RunOptions& run_options,
|
|||||||
RunStepRequest req;
|
RunStepRequest req;
|
||||||
RunStepResponse resp;
|
RunStepResponse resp;
|
||||||
|
|
||||||
|
*req.mutable_options() = run_options;
|
||||||
|
|
||||||
for (const auto& it : inputs) {
|
for (const auto& it : inputs) {
|
||||||
Tensor input_tensor = it.second;
|
Tensor input_tensor = it.second;
|
||||||
auto feed = req.add_feed();
|
auto feed = req.add_feed();
|
||||||
@ -206,6 +208,10 @@ Status GrpcSession::Run(const RunOptions& run_options,
|
|||||||
(*outputs)[fetch_it->second] = output;
|
(*outputs)[fetch_it->second] = output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (run_metadata) {
|
||||||
|
run_metadata->Swap(resp.mutable_metadata());
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,6 +75,9 @@ static SessionOptions Options(const string& target, int placement_period) {
|
|||||||
// string.
|
// string.
|
||||||
options.target = strings::StrCat("grpc://", target);
|
options.target = strings::StrCat("grpc://", target);
|
||||||
options.config.set_placement_period(placement_period);
|
options.config.set_placement_period(placement_period);
|
||||||
|
options.config.mutable_graph_options()
|
||||||
|
->mutable_optimizer_options()
|
||||||
|
->set_opt_level(OptimizerOptions::L0);
|
||||||
return options;
|
return options;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -307,9 +310,29 @@ TEST(GrpcSessionTest, MultiDevices) {
|
|||||||
TF_CHECK_OK(session->Create(def));
|
TF_CHECK_OK(session->Create(def));
|
||||||
{
|
{
|
||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
|
RunOptions options;
|
||||||
|
options.set_trace_level(RunOptions::FULL_TRACE);
|
||||||
|
RunMetadata metadata;
|
||||||
|
TF_CHECK_OK(
|
||||||
|
session->Run(options, {}, {c->name()}, {}, &outputs, &metadata));
|
||||||
ASSERT_EQ(1, outputs.size());
|
ASSERT_EQ(1, outputs.size());
|
||||||
IsSingleFloatValue(outputs[0], 6.0 * kSize);
|
IsSingleFloatValue(outputs[0], 6.0 * kSize);
|
||||||
|
|
||||||
|
const StepStats& ss = metadata.step_stats();
|
||||||
|
// NOTE(mrry): We only assert that `c` is placed correctly,
|
||||||
|
// because the current placement algorithm will move its
|
||||||
|
// inputs to be colocated with it, when it is the sole
|
||||||
|
// consumer.
|
||||||
|
bool c_placed_correctly = false;
|
||||||
|
for (const auto& dev : ss.dev_stats()) {
|
||||||
|
for (const auto& node : dev.node_stats()) {
|
||||||
|
if (node.node_name() == c->name() &&
|
||||||
|
dev.device() == c_dev.name()) {
|
||||||
|
c_placed_correctly = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(c_placed_correctly);
|
||||||
}
|
}
|
||||||
TF_CHECK_OK(session->Close());
|
TF_CHECK_OK(session->Close());
|
||||||
}
|
}
|
||||||
|
@ -325,7 +325,10 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
StepStatsCollector* collector = nullptr;
|
StepStatsCollector* collector = nullptr;
|
||||||
// TODO(mrry): Collect results from a profiler if available.
|
if (call->request.exec_opts().record_timeline()) {
|
||||||
|
collector = new StepStatsCollector(call->response.mutable_step_stats());
|
||||||
|
// TODO(mrry,pbar): GPU tracing for distributed steps.
|
||||||
|
}
|
||||||
CancellationManager* cm = new CancellationManager;
|
CancellationManager* cm = new CancellationManager;
|
||||||
call->SetCancelCallback([this, cm, step_id]() {
|
call->SetCancelCallback([this, cm, step_id]() {
|
||||||
cm->StartCancel();
|
cm->StartCancel();
|
||||||
@ -340,7 +343,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
|||||||
}
|
}
|
||||||
env_->graph_mgr->ExecuteAsync(
|
env_->graph_mgr->ExecuteAsync(
|
||||||
call->request.graph_handle(), step_id, call->request.exec_opts(),
|
call->request.graph_handle(), step_id, call->request.exec_opts(),
|
||||||
collector, cm, in, out, [this, call, cm, out, token](Status s) {
|
collector, cm, in, out,
|
||||||
|
[this, call, cm, out, token, collector](Status s) {
|
||||||
call->ClearCancelCallback();
|
call->ClearCancelCallback();
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
@ -359,6 +363,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
|||||||
val.AsProtoField(proto);
|
val.AsProtoField(proto);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
delete collector;
|
||||||
delete out;
|
delete out;
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
});
|
});
|
||||||
|
@ -39,6 +39,9 @@ message CostGraphDef {
|
|||||||
// Temporary memory used by this node.
|
// Temporary memory used by this node.
|
||||||
int64 temporary_memory_size = 6;
|
int64 temporary_memory_size = 6;
|
||||||
|
|
||||||
|
// Estimate of the computational cost of this node.
|
||||||
|
int64 compute_cost = 9;
|
||||||
|
|
||||||
// If true, the output is permanent: it can't be discarded, because this
|
// If true, the output is permanent: it can't be discarded, because this
|
||||||
// node is part of the "final output". Nodes may depend on final nodes.
|
// node is part of the "final output". Nodes may depend on final nodes.
|
||||||
bool is_final = 7;
|
bool is_final = 7;
|
||||||
|
@ -861,11 +861,11 @@ string DebugString(const GraphDef& instantiated_func_def) {
|
|||||||
|
|
||||||
string DebugStringWhole(const GraphDef& gdef) {
|
string DebugStringWhole(const GraphDef& gdef) {
|
||||||
string ret;
|
string ret;
|
||||||
for (auto fdef : gdef.library().function()) {
|
for (const auto& fdef : gdef.library().function()) {
|
||||||
strings::StrAppend(&ret, Print(fdef));
|
strings::StrAppend(&ret, Print(fdef));
|
||||||
}
|
}
|
||||||
strings::StrAppend(&ret, "\n");
|
strings::StrAppend(&ret, "\n");
|
||||||
for (auto ndef : gdef.node()) {
|
for (const auto& ndef : gdef.node()) {
|
||||||
strings::StrAppend(&ret, Print(ndef), "\n");
|
strings::StrAppend(&ret, Print(ndef), "\n");
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -33,7 +33,6 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class CancellationManager;
|
class CancellationManager;
|
||||||
class Node;
|
|
||||||
class OpKernel;
|
class OpKernel;
|
||||||
class ResourceMgr;
|
class ResourceMgr;
|
||||||
|
|
||||||
|
@ -31,11 +31,11 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
|
|||||||
VersionDef* versions = g.mutable_versions();
|
VersionDef* versions = g.mutable_versions();
|
||||||
versions->set_producer(TF_GRAPH_DEF_VERSION);
|
versions->set_producer(TF_GRAPH_DEF_VERSION);
|
||||||
versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
|
versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
|
||||||
for (auto n : nodes) {
|
for (const auto& n : nodes) {
|
||||||
*(g.add_node()) = n;
|
*(g.add_node()) = n;
|
||||||
}
|
}
|
||||||
auto lib = g.mutable_library();
|
auto lib = g.mutable_library();
|
||||||
for (auto f : funcs) {
|
for (const auto& f : funcs) {
|
||||||
*(lib->add_function()) = f;
|
*(lib->add_function()) = f;
|
||||||
}
|
}
|
||||||
return g;
|
return g;
|
||||||
@ -49,7 +49,7 @@ NodeDef NDef(const string& name, const string& op,
|
|||||||
NodeDef n;
|
NodeDef n;
|
||||||
n.set_name(name);
|
n.set_name(name);
|
||||||
n.set_op(op);
|
n.set_op(op);
|
||||||
for (auto in : inputs) n.add_input(in);
|
for (const auto& in : inputs) n.add_input(in);
|
||||||
n.set_device(device);
|
n.set_device(device);
|
||||||
for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
|
for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
|
||||||
return n;
|
return n;
|
||||||
|
@ -60,7 +60,7 @@ Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
|
|||||||
|
|
||||||
Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
|
Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
|
||||||
const AttrValue& allowed_values(attr.allowed_values());
|
const AttrValue& allowed_values(attr.allowed_values());
|
||||||
for (auto allowed : allowed_values.list().s()) {
|
for (const auto& allowed : allowed_values.list().s()) {
|
||||||
if (str == allowed) {
|
if (str == allowed) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -381,7 +381,7 @@ class OpKernelBuilderTest : public ::testing::Test {
|
|||||||
DeviceTypeVector devices;
|
DeviceTypeVector devices;
|
||||||
TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
|
TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
|
||||||
bool found = false;
|
bool found = false;
|
||||||
for (DeviceType dt : devices) {
|
for (const DeviceType& dt : devices) {
|
||||||
if (dt == device_type) {
|
if (dt == device_type) {
|
||||||
found = true;
|
found = true;
|
||||||
}
|
}
|
||||||
@ -414,7 +414,7 @@ class OpKernelBuilderTest : public ::testing::Test {
|
|||||||
DeviceTypeVector devices;
|
DeviceTypeVector devices;
|
||||||
if (errors::IsNotFound(status)) {
|
if (errors::IsNotFound(status)) {
|
||||||
TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
|
TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
|
||||||
for (DeviceType dt : devices) {
|
for (const DeviceType& dt : devices) {
|
||||||
EXPECT_NE(dt, device_type);
|
EXPECT_NE(dt, device_type);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -200,7 +200,7 @@ class TensorShape {
|
|||||||
DataType data_type() const { return static_cast<DataType>(buf()[13]); }
|
DataType data_type() const { return static_cast<DataType>(buf()[13]); }
|
||||||
void set_data_type(DataType dt) {
|
void set_data_type(DataType dt) {
|
||||||
// We only have 8 bits available to store DataType, so make sure it fits
|
// We only have 8 bits available to store DataType, so make sure it fits
|
||||||
DCHECK_LT(static_cast<uint32>(dt), 256);
|
DCHECK_LT(static_cast<uint32>(dt), 256u);
|
||||||
buf()[13] = static_cast<uint8>(dt);
|
buf()[13] = static_cast<uint8>(dt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,7 +94,9 @@ class TensorSlice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we have a full slice along dimension "d".
|
// If we have a full slice along dimension "d".
|
||||||
bool IsFullAt(int d) const { return lengths_[d] < 0; }
|
bool IsFullAt(int d) const {
|
||||||
|
return lengths_[d] == kFullExtent && starts_[d] == 0;
|
||||||
|
}
|
||||||
|
|
||||||
// If this is a full slice, i.e. IsFullAt(d) for every d.
|
// If this is a full slice, i.e. IsFullAt(d) for every d.
|
||||||
bool IsFull() const;
|
bool IsFull() const;
|
||||||
|
@ -273,8 +273,8 @@ TEST(TensorSliceTest, Deserialization) {
|
|||||||
TensorSlice ts3(proto3);
|
TensorSlice ts3(proto3);
|
||||||
|
|
||||||
// Both serializations should be interpreted the same.
|
// Both serializations should be interpreted the same.
|
||||||
EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString());
|
EXPECT_EQ("0,5:0,10:14,1:1,-1:-", ts2.DebugString());
|
||||||
EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString());
|
EXPECT_EQ("0,5:0,10:14,1:1,-1:-", ts3.DebugString());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TensorSliceTest, UpdateToCover) {
|
TEST(TensorSliceTest, UpdateToCover) {
|
||||||
|
@ -326,7 +326,7 @@ TEST_F(OptimizerCSETest, Constant_Dedup) {
|
|||||||
|
|
||||||
// A graph contains a bunch of constants.
|
// A graph contains a bunch of constants.
|
||||||
Graph g(OpRegistry::Global());
|
Graph g(OpRegistry::Global());
|
||||||
for (auto val : {a, b, c, d, d, c, b, a}) {
|
for (const auto& val : {a, b, c, d, d, c, b, a}) {
|
||||||
test::graph::Constant(&g, val); // Node name is n/_0, n/_1, ...
|
test::graph::Constant(&g, val); // Node name is n/_0, n/_1, ...
|
||||||
}
|
}
|
||||||
GraphDef gdef;
|
GraphDef gdef;
|
||||||
|
@ -74,7 +74,7 @@ inline bool IsGradientNode(const Graph* graph, const Node* node) {
|
|||||||
// Returns true if the root tensor op type is known, false otherwise.
|
// Returns true if the root tensor op type is known, false otherwise.
|
||||||
bool FindType(const Graph* graph, const Node* node, bool* signed_input,
|
bool FindType(const Graph* graph, const Node* node, bool* signed_input,
|
||||||
bool* range_given, float* input_min, float* input_max) {
|
bool* range_given, float* input_min, float* input_max) {
|
||||||
const string src_op = node->type_string();
|
const string& src_op = node->type_string();
|
||||||
if (src_op == "Const" || src_op == "Variable") {
|
if (src_op == "Const" || src_op == "Variable") {
|
||||||
*signed_input = true;
|
*signed_input = true;
|
||||||
*range_given = false;
|
*range_given = false;
|
||||||
|
@ -113,6 +113,36 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status ShapeRefiner::SetShape(const Node* node, int output_port,
|
||||||
|
shape_inference::ShapeHandle shape) {
|
||||||
|
auto c = GetContext(node);
|
||||||
|
if (c == nullptr) {
|
||||||
|
return errors::Internal("Could not find context for ", node->name());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output_port < 0 || output_port >= node->num_outputs()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"output_port '", output_port, "' is out of range, ", "node '",
|
||||||
|
node->name(), "' has ", node->num_outputs(), " outputs");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check compatibility
|
||||||
|
shape_inference::ShapeHandle existing_shape = c->output(output_port);
|
||||||
|
shape_inference::ShapeHandle unused;
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &unused));
|
||||||
|
|
||||||
|
c->set_output(output_port, shape);
|
||||||
|
|
||||||
|
// TODO(vrv): Do we need to propagate the new shape through all
|
||||||
|
// consumers that change their outputs? At the moment, python
|
||||||
|
// does not do this, but this seems like a nice feature.
|
||||||
|
|
||||||
|
// TODO(vrv): We might need to keep track of the fact that the
|
||||||
|
// existing shape is invalidated, in case we need to propagate
|
||||||
|
// this information to remote workers.
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status ShapeRefiner::ConstantValue(const Node* node, Tensor* tensor_storage,
|
Status ShapeRefiner::ConstantValue(const Node* node, Tensor* tensor_storage,
|
||||||
const Tensor** input_tensor) const {
|
const Tensor** input_tensor) const {
|
||||||
*input_tensor = nullptr;
|
*input_tensor = nullptr;
|
||||||
|
@ -46,6 +46,14 @@ class ShapeRefiner {
|
|||||||
// - The shape inference function returns an error.
|
// - The shape inference function returns an error.
|
||||||
Status AddNode(const Node* node);
|
Status AddNode(const Node* node);
|
||||||
|
|
||||||
|
// Sets 'node's 'output_port' output to have shape 'shape'.
|
||||||
|
//
|
||||||
|
// Returns an error if 'node' was not previously added to this
|
||||||
|
// object, if 'output_port' is invalid, or if 'shape' is
|
||||||
|
// not compatible with the existing shape of the output.
|
||||||
|
Status SetShape(const Node* node, int output_port,
|
||||||
|
shape_inference::ShapeHandle shape);
|
||||||
|
|
||||||
// Returns the InferenceContext for 'node', if present.
|
// Returns the InferenceContext for 'node', if present.
|
||||||
shape_inference::InferenceContext* GetContext(const Node* node) const {
|
shape_inference::InferenceContext* GetContext(const Node* node) const {
|
||||||
auto it = node_to_context_.find(node);
|
auto it = node_to_context_.find(node);
|
||||||
|
@ -92,6 +92,33 @@ TEST(ShapeRefinerTest, BadShapes) {
|
|||||||
ASSERT_EQ("Dimensions must be equal, but are 1 and 2", s.error_message());
|
ASSERT_EQ("Dimensions must be equal, but are 1 and 2", s.error_message());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ShapeRefinerTest, SetShape) {
|
||||||
|
ShapeRefiner m;
|
||||||
|
|
||||||
|
Scope root = Scope::NewRootScope();
|
||||||
|
auto a = ops::Const(root, {{1.0f}, {2.0f}});
|
||||||
|
|
||||||
|
TF_ASSERT_OK(m.AddNode(a.node()));
|
||||||
|
|
||||||
|
auto ic = m.GetContext(a.node());
|
||||||
|
ASSERT_NE(nullptr, ic);
|
||||||
|
shape_inference::ShapeHandle h = ic->MakeShape({2, ic->UnknownDim()});
|
||||||
|
TF_ASSERT_OK(m.SetShape(a.node(), 0, h));
|
||||||
|
EXPECT_SHAPE("[2,?]", m, a, 0);
|
||||||
|
|
||||||
|
// Out of range.
|
||||||
|
ASSERT_FALSE(m.SetShape(a.node(), 1, h).ok());
|
||||||
|
ASSERT_FALSE(m.SetShape(a.node(), -1, h).ok());
|
||||||
|
|
||||||
|
auto b = ops::Const(root, {{1.0f}, {2.0f}});
|
||||||
|
// Forget to add node first.
|
||||||
|
ASSERT_FALSE(m.SetShape(b.node(), 0, h).ok());
|
||||||
|
|
||||||
|
// Set an incompatible shape (3 vs 2)
|
||||||
|
h = ic->MakeShape({3, ic->UnknownDim()});
|
||||||
|
ASSERT_FALSE(m.SetShape(a.node(), 0, h).ok());
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ShapeRefinerTest, PropagateConstants) {
|
TEST(ShapeRefinerTest, PropagateConstants) {
|
||||||
// Reduction dimension is a variable, so we don't know its value.
|
// Reduction dimension is a variable, so we don't know its value.
|
||||||
// So the output shape value is unknown (though its rank is known).
|
// So the output shape value is unknown (though its rank is known).
|
||||||
|
@ -235,7 +235,15 @@ Status RewriteGraphForExecution(
|
|||||||
"Must specify at least one target to fetch or execute.");
|
"Must specify at least one target to fetch or execute.");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_set<string> endpoints(fed_outputs.begin(), fed_outputs.end());
|
std::unordered_set<string> endpoints;
|
||||||
|
for (const string& endpoint_name : fed_outputs) {
|
||||||
|
auto result = endpoints.insert(endpoint_name);
|
||||||
|
if (!result.second) {
|
||||||
|
return errors::InvalidArgument("Endpoint \"", endpoint_name,
|
||||||
|
"\" fed more than once.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto& fetch : fetch_outputs) {
|
for (const auto& fetch : fetch_outputs) {
|
||||||
if (endpoints.count(fetch) > 0) {
|
if (endpoints.count(fetch) > 0) {
|
||||||
return errors::InvalidArgument(fetch, " is both fed and fetched.");
|
return errors::InvalidArgument(fetch, " is both fed and fetched.");
|
||||||
|
@ -491,6 +491,27 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "conv_ops_test",
|
||||||
|
size = "small",
|
||||||
|
deps = [
|
||||||
|
":conv_ops",
|
||||||
|
":image",
|
||||||
|
":ops_testutil",
|
||||||
|
":ops_util",
|
||||||
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "example_parsing_ops_test",
|
name = "example_parsing_ops_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
@ -1325,6 +1346,7 @@ tf_kernel_library(
|
|||||||
hdrs = [
|
hdrs = [
|
||||||
"conv_grad_ops.h",
|
"conv_grad_ops.h",
|
||||||
"deep_conv2d.h",
|
"deep_conv2d.h",
|
||||||
|
"gemm_functors.h",
|
||||||
"winograd_transform.h",
|
"winograd_transform.h",
|
||||||
],
|
],
|
||||||
prefix = "conv_ops",
|
prefix = "conv_ops",
|
||||||
@ -1332,6 +1354,7 @@ tf_kernel_library(
|
|||||||
":bounds_check",
|
":bounds_check",
|
||||||
":conv_2d",
|
":conv_2d",
|
||||||
":conv_3d",
|
":conv_3d",
|
||||||
|
":image_resizer_state",
|
||||||
":ops_util",
|
":ops_util",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
@ -1958,6 +1981,7 @@ filegroup(
|
|||||||
"control_flow_ops.h",
|
"control_flow_ops.h",
|
||||||
"conv_2d.h",
|
"conv_2d.h",
|
||||||
"conv_ops.h",
|
"conv_ops.h",
|
||||||
|
"depthwise_conv_op.h",
|
||||||
"image_resizer_state.h",
|
"image_resizer_state.h",
|
||||||
"maxpooling_op.h",
|
"maxpooling_op.h",
|
||||||
"pad_op.h",
|
"pad_op.h",
|
||||||
@ -1998,6 +2022,7 @@ filegroup(
|
|||||||
"cwise_op_div.cc",
|
"cwise_op_div.cc",
|
||||||
"cwise_op_equal_to.cc",
|
"cwise_op_equal_to.cc",
|
||||||
"cwise_op_exp.cc",
|
"cwise_op_exp.cc",
|
||||||
|
"cwise_op_floor.cc",
|
||||||
"cwise_op_greater.cc",
|
"cwise_op_greater.cc",
|
||||||
"cwise_op_inverse.cc",
|
"cwise_op_inverse.cc",
|
||||||
"cwise_op_isfinite.cc",
|
"cwise_op_isfinite.cc",
|
||||||
@ -2017,6 +2042,7 @@ filegroup(
|
|||||||
"cwise_op_tanh.cc",
|
"cwise_op_tanh.cc",
|
||||||
"deep_conv2d.cc",
|
"deep_conv2d.cc",
|
||||||
"deep_conv2d.h",
|
"deep_conv2d.h",
|
||||||
|
"depthwise_conv_op.cc",
|
||||||
"dynamic_partition_op.cc",
|
"dynamic_partition_op.cc",
|
||||||
"winograd_transform.h",
|
"winograd_transform.h",
|
||||||
":android_extended_ops_headers",
|
":android_extended_ops_headers",
|
||||||
|
@ -67,7 +67,7 @@ class ArgOp : public OpKernel {
|
|||||||
input.shape().DebugString()));
|
input.shape().DebugString()));
|
||||||
|
|
||||||
TensorShape output_shape;
|
TensorShape output_shape;
|
||||||
TensorShape input_shape = input.shape();
|
const TensorShape& input_shape = input.shape();
|
||||||
for (int d = 0; d < input_dims - 1; ++d) {
|
for (int d = 0; d < input_dims - 1; ++d) {
|
||||||
output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
|
output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,7 @@ class ExtractGlimpseOp : public OpKernel {
|
|||||||
// depth).
|
// depth).
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
const Tensor& input = context->input(0);
|
const Tensor& input = context->input(0);
|
||||||
const TensorShape input_shape = input.shape();
|
const TensorShape& input_shape = input.shape();
|
||||||
const int32 num_dims = input_shape.dims();
|
const int32 num_dims = input_shape.dims();
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, num_dims == 4,
|
context, num_dims == 4,
|
||||||
|
@ -190,7 +190,7 @@ class ComputeAccidentalHitsOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
const Tensor& in_true_candidates = context->input(0);
|
const Tensor& in_true_candidates = context->input(0);
|
||||||
TensorShape in_true_candidates_shape = in_true_candidates.shape();
|
const TensorShape& in_true_candidates_shape = in_true_candidates.shape();
|
||||||
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(in_true_candidates_shape) &&
|
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(in_true_candidates_shape) &&
|
||||||
in_true_candidates_shape.dim_size(1) == num_true_,
|
in_true_candidates_shape.dim_size(1) == num_true_,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
|
@ -37,15 +37,12 @@ struct scalar_const_op {
|
|||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_const_op(const T* v) : val(v) {}
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_const_op(const T* v) : val(v) {}
|
||||||
|
|
||||||
template <typename Index>
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()() const {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(Index,
|
|
||||||
Index = 0) const {
|
|
||||||
return *val;
|
return *val;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Index, typename PacketType = Packet>
|
template <typename PacketType = Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packetOp() const {
|
||||||
packetOp(Index, Index = 0) const {
|
|
||||||
return internal::pset1<PacketType>(*val);
|
return internal::pset1<PacketType>(*val);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_KERNELS_CONV_OPS_H_
|
#define TENSORFLOW_KERNELS_CONV_OPS_H_
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
@ -38,6 +39,16 @@ class LaunchConv2DOp {
|
|||||||
TensorFormat data_format);
|
TensorFormat data_format);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Used to keep track of persistent memory buffers used within the op.
|
||||||
|
template <class T, size_t size>
|
||||||
|
struct Im2ColBufferResource : public ResourceBase {
|
||||||
|
// This mutex ensures that only a single operation at a time is able to use
|
||||||
|
// the buffer memory held by this resource.
|
||||||
|
mutex mu;
|
||||||
|
T data[size];
|
||||||
|
string DebugString() { return "Im2ColBufferResource"; }
|
||||||
|
};
|
||||||
|
|
||||||
#ifdef GOOGLE_CUDA
|
#ifdef GOOGLE_CUDA
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class LaunchConv2DOp<Eigen::GpuDevice, T> {
|
class LaunchConv2DOp<Eigen::GpuDevice, T> {
|
||||||
|
486
tensorflow/core/kernels/conv_ops_fused.cc
Normal file
486
tensorflow/core/kernels/conv_ops_fused.cc
Normal file
@ -0,0 +1,486 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Implements convolution operations with other kernels baked into the
|
||||||
|
// processing, to optimize latency and memory usage.
|
||||||
|
|
||||||
|
#include <string.h>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_slice.h"
|
||||||
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
|
#include "tensorflow/core/kernels/conv_ops.h"
|
||||||
|
#include "tensorflow/core/kernels/gemm_functors.h"
|
||||||
|
#include "tensorflow/core/kernels/image_resizer_state.h"
|
||||||
|
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||||
|
#include "tensorflow/core/util/padding.h"
|
||||||
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Combines bilinear resizing and mirror padding into the im2col transformation
|
||||||
|
// stage of convolution,
|
||||||
|
template <class T1, class T2, class T3, class TGemmFunctor>
|
||||||
|
class FusedResizeAndPadConvFunctor {
|
||||||
|
public:
|
||||||
|
void operator()(OpKernelContext* context, const Tensor& input,
|
||||||
|
int input_batches, int resized_height, int resized_width,
|
||||||
|
int padded_height, int padded_width, int input_depth,
|
||||||
|
const T2* filter_data, int filter_height, int filter_width,
|
||||||
|
int filter_count, int stride_rows, int stride_cols,
|
||||||
|
Padding padding, T3* output_data, int output_height,
|
||||||
|
int output_width, const ImageResizerState& st,
|
||||||
|
int top_padding, int bottom_padding, int left_padding,
|
||||||
|
int right_padding, int pad_offset) {
|
||||||
|
if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
|
||||||
|
(input_depth <= 0)) {
|
||||||
|
LOG(WARNING) << "Conv2D was called with bad input dimensions: "
|
||||||
|
<< input_batches << ", " << padded_height << ", "
|
||||||
|
<< padded_width << ", " << input_depth;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
|
||||||
|
LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
|
||||||
|
<< filter_width << ", " << filter_height << ", "
|
||||||
|
<< filter_count;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if ((output_width <= 0) || (output_height <= 0)) {
|
||||||
|
LOG(WARNING) << "Conv2D was called with bad output width or height: "
|
||||||
|
<< output_width << ", " << output_height;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// These calculations define how the patches will be positioned within the
|
||||||
|
// input image. The actual definitions are quite complex, and rely on the
|
||||||
|
// previously-calculated output size.
|
||||||
|
int filter_left_offset;
|
||||||
|
int filter_top_offset;
|
||||||
|
if (padding == VALID) {
|
||||||
|
filter_left_offset =
|
||||||
|
((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
|
||||||
|
2;
|
||||||
|
filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
|
||||||
|
padded_height + 1) /
|
||||||
|
2;
|
||||||
|
} else {
|
||||||
|
filter_left_offset =
|
||||||
|
((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
|
||||||
|
filter_top_offset =
|
||||||
|
((output_height - 1) * stride_rows + filter_height - padded_height) /
|
||||||
|
2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The im2col buffer has # of patches rows, and # of filters cols.
|
||||||
|
// It's laid out like this, in row major order in memory:
|
||||||
|
// < filter value count >
|
||||||
|
// ^ +---------------------+
|
||||||
|
// patch | |
|
||||||
|
// count | |
|
||||||
|
// v +---------------------+
|
||||||
|
// Each patch row contains a filter_width x filter_height patch of the
|
||||||
|
// input, with the depth channel as the most contiguous in memory, followed
|
||||||
|
// by the width, then the height. This is the standard memory order in the
|
||||||
|
// image world if it helps to visualize it.
|
||||||
|
const int filter_value_count = filter_width * filter_height * input_depth;
|
||||||
|
|
||||||
|
// We don't want to allocate a buffer to hold all the patches if the size is
|
||||||
|
// going to be extremely large, so break it into chunks if it's bigger than
|
||||||
|
// a limit. Each chunk will be processed serially, so we can refill the
|
||||||
|
// buffer for the next chunk and reuse it, keeping maximum memory size down.
|
||||||
|
// In this case, we've picked 16 megabytes as a reasonable limit.
|
||||||
|
const size_t max_chunk_size = (16 * 1024 * 1024);
|
||||||
|
OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= max_chunk_size,
|
||||||
|
errors::InvalidArgument("Im2Col patch too large for buffer"));
|
||||||
|
const size_t patches_per_chunk =
|
||||||
|
max_chunk_size / (filter_value_count * sizeof(T1));
|
||||||
|
// Because memory allocation is very expensive on mobile platforms, try to
|
||||||
|
// allocate a persistent buffer that will be kept around between calls. We
|
||||||
|
// use TensorFlow's resource management to ensure that the memory will be
|
||||||
|
// released when the session is over.
|
||||||
|
Im2ColBufferResource<T1, max_chunk_size>* im2col_buffer_resource;
|
||||||
|
std::function<Status(Im2ColBufferResource<T1, max_chunk_size>**)> creator =
|
||||||
|
[](Im2ColBufferResource<T1, max_chunk_size>** resource) {
|
||||||
|
*resource = new Im2ColBufferResource<T1, max_chunk_size>();
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
|
||||||
|
"Conv2d", "im2col_buffer",
|
||||||
|
&im2col_buffer_resource, creator));
|
||||||
|
// This means that multiple ops can't be run simultaneously on different
|
||||||
|
// threads, because we have a single shared resource. The platforms this is
|
||||||
|
// aimed at have intra-op parallelism as their focus though, so it shouldn't
|
||||||
|
// be an issue.
|
||||||
|
mutex_lock lock_buffer(im2col_buffer_resource->mu);
|
||||||
|
core::ScopedUnref unref_buffer(im2col_buffer_resource);
|
||||||
|
T1* im2col_buffer = im2col_buffer_resource->data;
|
||||||
|
|
||||||
|
typename TTypes<T1, 4>::ConstTensor input_data = input.tensor<T1, 4>();
|
||||||
|
|
||||||
|
for (int batch = 0; batch < input_batches; ++batch) {
|
||||||
|
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||||
|
const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
|
||||||
|
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||||
|
const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
|
||||||
|
const int patch_index = (batch * output_width * output_height) +
|
||||||
|
(out_y * output_width) + out_x;
|
||||||
|
const int patch_index_within_chunk = patch_index % patches_per_chunk;
|
||||||
|
T1* im2col_patch_start =
|
||||||
|
im2col_buffer + (patch_index_within_chunk * filter_value_count);
|
||||||
|
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||||
|
const int conv_in_y = in_y_origin + filter_y;
|
||||||
|
float in_y = (conv_in_y - top_padding);
|
||||||
|
if (in_y < 0) {
|
||||||
|
in_y = -(in_y + 1.0f - pad_offset);
|
||||||
|
} else if (in_y >= resized_height) {
|
||||||
|
in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
|
||||||
|
}
|
||||||
|
in_y *= st.height_scale;
|
||||||
|
const int64 top_y_index = static_cast<int64>(std::floor(in_y));
|
||||||
|
const int64 bottom_y_index = std::min(
|
||||||
|
static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
|
||||||
|
const T1 y_lerp = in_y - top_y_index;
|
||||||
|
T1* im2col_row_start =
|
||||||
|
im2col_patch_start + (filter_y * filter_width * input_depth);
|
||||||
|
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
|
||||||
|
const int conv_in_x = in_x_origin + filter_x;
|
||||||
|
float in_x = (conv_in_x - left_padding);
|
||||||
|
if (in_x < 0) {
|
||||||
|
in_x = -(in_x + 1.0f - pad_offset);
|
||||||
|
} else if (in_x >= resized_width) {
|
||||||
|
in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
|
||||||
|
}
|
||||||
|
in_x *= st.width_scale;
|
||||||
|
const int64 left_x_index = static_cast<int64>(std::floor(in_x));
|
||||||
|
const int64 right_x_index = std::min(
|
||||||
|
static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
|
||||||
|
const T1 x_lerp = in_x - left_x_index;
|
||||||
|
T1* im2col_row_pixel =
|
||||||
|
im2col_row_start + (filter_x * input_depth);
|
||||||
|
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
|
||||||
|
T1 in_value;
|
||||||
|
if ((conv_in_x >= 0) && (conv_in_x < padded_width) &&
|
||||||
|
(conv_in_y >= 0) && (conv_in_y < padded_height)) {
|
||||||
|
const T1 top_left(
|
||||||
|
input_data(batch, top_y_index, left_x_index, in_channel));
|
||||||
|
const T1 top_right(input_data(batch, top_y_index,
|
||||||
|
right_x_index, in_channel));
|
||||||
|
const T1 bottom_left(input_data(batch, bottom_y_index,
|
||||||
|
left_x_index, in_channel));
|
||||||
|
const T1 bottom_right(input_data(batch, bottom_y_index,
|
||||||
|
right_x_index, in_channel));
|
||||||
|
const T1 top = top_left + (top_right - top_left) * x_lerp;
|
||||||
|
const T1 bottom =
|
||||||
|
bottom_left + (bottom_right - bottom_left) * x_lerp;
|
||||||
|
in_value = top + (bottom - top) * y_lerp;
|
||||||
|
} else {
|
||||||
|
in_value = T1(0);
|
||||||
|
}
|
||||||
|
im2col_row_pixel[in_channel] = in_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const bool is_last_in_chunk =
|
||||||
|
(patch_index_within_chunk == (patches_per_chunk - 1));
|
||||||
|
const bool is_last_overall =
|
||||||
|
((batch == (input_batches - 1)) &&
|
||||||
|
(out_y == (output_height - 1)) && (out_x == (output_width - 1)));
|
||||||
|
if (is_last_in_chunk || is_last_overall) {
|
||||||
|
// Now we've assembled a set of image patches into a matrix, apply a
|
||||||
|
// GEMM matrix multiply of the patches as rows, times the filter
|
||||||
|
// weights in columns, to get partial results in the output matrix.
|
||||||
|
const int how_many_patches = patch_index_within_chunk + 1;
|
||||||
|
const int m = how_many_patches;
|
||||||
|
const int n = filter_count;
|
||||||
|
const int k = filter_value_count;
|
||||||
|
const int lda = filter_value_count;
|
||||||
|
const int ldb = filter_count;
|
||||||
|
const int ldc = filter_count;
|
||||||
|
const size_t start_patch_index =
|
||||||
|
patch_index - (how_many_patches - 1);
|
||||||
|
T3* chunk_output_data =
|
||||||
|
output_data + (start_patch_index * filter_count);
|
||||||
|
TGemmFunctor gemm_functor;
|
||||||
|
gemm_functor(m, n, k, im2col_buffer, lda, filter_data, ldb,
|
||||||
|
chunk_output_data, ldc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Implements a version of convolution with bilinear resizing and mirror padding
|
||||||
|
// included.
|
||||||
|
template <class T, class TConvFunctor>
|
||||||
|
class FusedResizeConv2DUsingGemmOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->GetAttr("resize_align_corners", &align_corners_));
|
||||||
|
MirrorPadMode mode;
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));
|
||||||
|
|
||||||
|
switch (mode) {
|
||||||
|
case MirrorPadMode::SYMMETRIC: {
|
||||||
|
offset_ = 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case MirrorPadMode::REFLECT: {
|
||||||
|
offset_ = 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
OP_REQUIRES(context, false,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"mode must be either REFLECT or SYMMETRIC."));
|
||||||
|
}
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
|
||||||
|
OP_REQUIRES(context, strides_.size() == 4,
|
||||||
|
errors::InvalidArgument("Sliding window strides field must "
|
||||||
|
"specify 4 dimensions"));
|
||||||
|
const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
|
||||||
|
const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, stride_n == 1 && stride_c == 1,
|
||||||
|
errors::InvalidArgument("Current implementation does not yet support "
|
||||||
|
"strides in the batch and depth dimensions."));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
// Input tensor is of the following dimensions:
|
||||||
|
// [ batch, in_rows, in_cols, in_depth ]
|
||||||
|
const Tensor& input = context->input(0);
|
||||||
|
OP_REQUIRES(context, (input.shape().num_elements() > 0),
|
||||||
|
errors::InvalidArgument("Input tensor can't be empty"));
|
||||||
|
|
||||||
|
ImageResizerState st(align_corners_);
|
||||||
|
st.ValidateAndCalculateOutputSize(context, input);
|
||||||
|
if (!context->status().ok()) return;
|
||||||
|
const TensorShape resized_shape(
|
||||||
|
{input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
|
||||||
|
|
||||||
|
const Tensor& paddings = context->input(2);
|
||||||
|
|
||||||
|
const int dims = resized_shape.dims();
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, TensorShapeUtils::IsMatrix(paddings.shape()) &&
|
||||||
|
paddings.dim_size(1) == 2,
|
||||||
|
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
|
||||||
|
paddings.shape().DebugString()));
|
||||||
|
const int fixed_dims =
|
||||||
|
(allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
|
||||||
|
? 1
|
||||||
|
: dims;
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, fixed_dims == paddings.dim_size(0),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"The first dimension of paddings must be the rank of inputs: ",
|
||||||
|
fixed_dims, " ", paddings.shape().DebugString(), " ",
|
||||||
|
resized_shape.DebugString()));
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, dims == paddings.dim_size(0),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"The first dimension of paddings must be the rank of inputs: ",
|
||||||
|
dims, " ", paddings.shape().DebugString(), " ",
|
||||||
|
resized_shape.DebugString()));
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, dims == 4,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Fused mirror padding only supports four-dimensional inputs, but ",
|
||||||
|
dims, " requested"));
|
||||||
|
|
||||||
|
// Compute the shape of the output tensor, and allocate it.
|
||||||
|
TensorShape padded_shape;
|
||||||
|
TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
|
||||||
|
for (int d = 0; d < dims; ++d) {
|
||||||
|
const int32 before =
|
||||||
|
paddings_matrix(d, 0); // Pad before existing elements.
|
||||||
|
const int32 after =
|
||||||
|
paddings_matrix(d, 1); // Pad after exisitng elements.
|
||||||
|
OP_REQUIRES(context, before >= 0 && after >= 0,
|
||||||
|
errors::InvalidArgument("paddings must be non-negative: ",
|
||||||
|
before, " ", after));
|
||||||
|
if (offset_ == 0) { // SYMMETRIC mode.
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, before <= resized_shape.dim_size(d) &&
|
||||||
|
after <= resized_shape.dim_size(d),
|
||||||
|
errors::InvalidArgument("paddings must be no greater "
|
||||||
|
"than the dimension size: ",
|
||||||
|
before, ", ", after, " greater than ",
|
||||||
|
resized_shape.dim_size(d)));
|
||||||
|
} else if (offset_ == 1) { // REFLECT mode.
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, before < resized_shape.dim_size(d) &&
|
||||||
|
after < resized_shape.dim_size(d),
|
||||||
|
errors::InvalidArgument("paddings must be less than"
|
||||||
|
" the dimension size: ",
|
||||||
|
before, ", ", after, " not less than ",
|
||||||
|
resized_shape.dim_size(d)));
|
||||||
|
}
|
||||||
|
padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
|
||||||
|
}
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Fused mirror padding only support spatial padding, not batches: ",
|
||||||
|
paddings.DebugString()));
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Fused mirror padding only support spatial padding, not channels: ",
|
||||||
|
paddings.DebugString()));
|
||||||
|
const int32 top_padding = paddings_matrix(1, 0);
|
||||||
|
const int32 bottom_padding = paddings_matrix(1, 1);
|
||||||
|
const int32 left_padding = paddings_matrix(2, 0);
|
||||||
|
const int32 right_padding = paddings_matrix(2, 1);
|
||||||
|
|
||||||
|
// Input filter is of the following dimensions:
|
||||||
|
// [ filter_rows, filter_cols, in_depth, out_depth]
|
||||||
|
const Tensor& filter = context->input(3);
|
||||||
|
|
||||||
|
// For 2D convolution, there should be 4 dimensions.
|
||||||
|
OP_REQUIRES(context, padded_shape.dims() == 4,
|
||||||
|
errors::InvalidArgument("input must be 4-dimensional",
|
||||||
|
padded_shape.DebugString()));
|
||||||
|
OP_REQUIRES(context, filter.dims() == 4,
|
||||||
|
errors::InvalidArgument("filter must be 4-dimensional: ",
|
||||||
|
filter.shape().DebugString()));
|
||||||
|
|
||||||
|
// We only check the first three dims, since the depth is accessed as an
|
||||||
|
// int64 below.
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
|
||||||
|
std::numeric_limits<int>::max()),
|
||||||
|
errors::InvalidArgument("filter too large"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// The last dimension for input is in_depth. It must be the same as the
|
||||||
|
// filter's in_depth.
|
||||||
|
const int64 in_depth = padded_shape.dim_size(3);
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, in_depth == filter.dim_size(2),
|
||||||
|
errors::InvalidArgument("input and filter must have the same depth: ",
|
||||||
|
in_depth, " vs ", filter.dim_size(2)));
|
||||||
|
|
||||||
|
// The last dimension for filter is out_depth.
|
||||||
|
const int out_depth = static_cast<int>(filter.dim_size(3));
|
||||||
|
|
||||||
|
// The second dimension for input is rows/height.
|
||||||
|
// The first dimension for filter is rows/height.
|
||||||
|
const int64 padded_rows_raw = padded_shape.dim_size(1);
|
||||||
|
OP_REQUIRES(context, FastBoundsCheck(padded_rows_raw,
|
||||||
|
std::numeric_limits<int>::max()),
|
||||||
|
errors::InvalidArgument("Input rows too large"));
|
||||||
|
const int padded_rows = static_cast<int>(padded_rows_raw);
|
||||||
|
const int filter_rows = static_cast<int>(filter.dim_size(0));
|
||||||
|
const int resized_rows = static_cast<int>(resized_shape.dim_size(1));
|
||||||
|
|
||||||
|
// The third dimension for input is columns/width.
|
||||||
|
// The second dimension for filter is columns/width.
|
||||||
|
const int64 padded_cols_raw = padded_shape.dim_size(2);
|
||||||
|
OP_REQUIRES(context, FastBoundsCheck(padded_cols_raw,
|
||||||
|
std::numeric_limits<int>::max()),
|
||||||
|
errors::InvalidArgument("Input cols too large"));
|
||||||
|
const int padded_cols = static_cast<int>(padded_cols_raw);
|
||||||
|
const int filter_cols = static_cast<int>(filter.dim_size(1));
|
||||||
|
const int resized_cols = static_cast<int>(resized_shape.dim_size(2));
|
||||||
|
|
||||||
|
// The first dimension for input is batch.
|
||||||
|
const int64 batch_raw = padded_shape.dim_size(0);
|
||||||
|
OP_REQUIRES(context,
|
||||||
|
FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
|
||||||
|
errors::InvalidArgument("batch is too large"));
|
||||||
|
const int batch = static_cast<int>(batch_raw);
|
||||||
|
|
||||||
|
// For now we take the stride from the second and third dimensions only (we
|
||||||
|
// do not support striding on the batch or depth dimension).
|
||||||
|
const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
|
||||||
|
const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');
|
||||||
|
|
||||||
|
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
|
||||||
|
padding_, &out_rows, &pad_rows));
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
|
||||||
|
padding_, &out_cols, &pad_cols));
|
||||||
|
TensorShape out_shape =
|
||||||
|
ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
|
||||||
|
OP_REQUIRES(context, (out_shape.num_elements() > 0),
|
||||||
|
errors::InvalidArgument("Output tensor can't be empty"));
|
||||||
|
|
||||||
|
// Output tensor is of the following dimensions:
|
||||||
|
// [ in_batch, out_rows, out_cols, out_depth ]
|
||||||
|
Tensor* output = nullptr;
|
||||||
|
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||||
|
|
||||||
|
VLOG(2) << "Conv2D: in_depth = " << in_depth
|
||||||
|
<< ", padded_cols = " << padded_cols
|
||||||
|
<< ", filter_cols = " << filter_cols
|
||||||
|
<< ", padded_rows = " << padded_rows
|
||||||
|
<< ", filter_rows = " << filter_rows
|
||||||
|
<< ", stride_rows = " << stride_rows
|
||||||
|
<< ", stride_cols = " << stride_cols
|
||||||
|
<< ", out_depth = " << out_depth;
|
||||||
|
|
||||||
|
// If there is nothing to compute, return.
|
||||||
|
if (out_shape.num_elements() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TConvFunctor conv_functor;
|
||||||
|
conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
|
||||||
|
padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
|
||||||
|
filter_cols, out_depth, stride_rows, stride_cols, padding_,
|
||||||
|
output->flat<T>().data(), out_rows, out_cols, st, top_padding,
|
||||||
|
bottom_padding, left_padding, right_padding, offset_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int32> strides_;
|
||||||
|
Padding padding_;
|
||||||
|
bool align_corners_;
|
||||||
|
int offset_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Registers FusedResizeConv2DUsingGemmOp as the CPU kernel for the
// "FusedResizeAndPadConv2D" op with type constraint T, using
// FusedResizeAndPadConvFunctor backed by FastGemmFunctor. The macro is
// instantiated for float by the TF_CALL_float line that follows it.
#define REGISTER_FUSED(T) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("FusedResizeAndPadConv2D") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<T>("T"), \
|
||||||
|
FusedResizeConv2DUsingGemmOp< \
|
||||||
|
T, \
|
||||||
|
FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);
|
||||||
|
|
||||||
|
TF_CALL_float(REGISTER_FUSED);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
240
tensorflow/core/kernels/conv_ops_test.cc
Normal file
240
tensorflow/core/kernels/conv_ops_test.cc
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||||
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class FusedResizePadConvOpTest : public OpsTestBase {
|
||||||
|
protected:
|
||||||
|
void HandwrittenConv() {
|
||||||
|
const int stride = 1;
|
||||||
|
TF_EXPECT_OK(NodeDefBuilder("fused_resize_op", "FusedResizeAndPadConv2D")
|
||||||
|
.Input(FakeInput(DT_FLOAT))
|
||||||
|
.Input(FakeInput(DT_INT32))
|
||||||
|
.Input(FakeInput(DT_INT32))
|
||||||
|
.Input(FakeInput(DT_FLOAT))
|
||||||
|
.Attr("T", DT_FLOAT)
|
||||||
|
.Attr("resize_align_corners", false)
|
||||||
|
.Attr("mode", "REFLECT")
|
||||||
|
.Attr("strides", {1, stride, stride, 1})
|
||||||
|
.Attr("padding", "SAME")
|
||||||
|
.Finalize(node_def()));
|
||||||
|
TF_EXPECT_OK(InitOp());
|
||||||
|
const int depth = 1;
|
||||||
|
const int image_width = 4;
|
||||||
|
const int image_height = 3;
|
||||||
|
const int image_batch_count = 1;
|
||||||
|
// The image matrix is:
|
||||||
|
// | 1 | 2 | 3 | 4 |
|
||||||
|
// | 5 | 6 | 7 | 8 |
|
||||||
|
// | 9 | 10 | 11 | 12 |
|
||||||
|
Tensor image(DT_FLOAT,
|
||||||
|
{image_batch_count, image_height, image_width, depth});
|
||||||
|
test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||||
|
|
||||||
|
// The filter matrix is:
|
||||||
|
// | 1 | 4 | 7 |
|
||||||
|
// | 2 | 5 | 8 |
|
||||||
|
// | 3 | 6 | 9 |
|
||||||
|
const int filter_size = 3;
|
||||||
|
const int filter_count = 1;
|
||||||
|
Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count});
|
||||||
|
test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});
|
||||||
|
|
||||||
|
const int resized_width = image_width;
|
||||||
|
const int resized_height = image_height;
|
||||||
|
|
||||||
|
const int top_padding = 0;
|
||||||
|
const int bottom_padding = 0;
|
||||||
|
const int left_padding = 0;
|
||||||
|
const int right_padding = 0;
|
||||||
|
|
||||||
|
AddInputFromArray<float>(image.shape(), image.flat<float>());
|
||||||
|
AddInputFromArray<int32>(TensorShape({2}), {resized_height, resized_width});
|
||||||
|
AddInputFromArray<int32>(
|
||||||
|
TensorShape({4, 2}),
|
||||||
|
{0, 0, top_padding, bottom_padding, left_padding, right_padding, 0, 0});
|
||||||
|
AddInputFromArray<float>(filter.shape(), filter.flat<float>());
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
|
// We're sliding the 3x3 filter across the 3x4 image, with accesses outside
|
||||||
|
// the input set to zero because we're using the 'SAME' padding mode.
|
||||||
|
// The calculations behind the expected output are:
|
||||||
|
// (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
|
||||||
|
// (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
|
||||||
|
// (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
|
||||||
|
// (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
|
||||||
|
// (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
|
||||||
|
// (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
|
||||||
|
// (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
|
||||||
|
// (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
|
||||||
|
// (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
|
||||||
|
// (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
|
||||||
|
// (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
|
||||||
|
// (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
|
||||||
|
// This means we should end up with this matrix:
|
||||||
|
// | 105 | 150 | 183 | 95 |
|
||||||
|
// | 235 | 312 | 357 | 178 |
|
||||||
|
// | 187 | 234 | 261 | 121 |
|
||||||
|
const int expected_width = image_width;
|
||||||
|
const int expected_height = image_height * filter_count;
|
||||||
|
Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height,
|
||||||
|
expected_width, filter_count}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121});
|
||||||
|
const Tensor& output = *GetOutput(0);
|
||||||
|
test::ExpectTensorNear<float>(expected, output, 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CompareFusedAndSeparate(int input_width, int input_height,
|
||||||
|
int input_depth, int resize_width,
|
||||||
|
int resize_height, int y_padding, int x_padding,
|
||||||
|
int filter_size, int filter_count,
|
||||||
|
bool resize_align_corners, string pad_mode,
|
||||||
|
int stride, string padding) {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const size_t input_data_size = input_height * input_width * input_depth;
|
||||||
|
Tensor input_data(DT_FLOAT,
|
||||||
|
TensorShape({1, input_height, input_width, input_depth}));
|
||||||
|
for (int i = 0; i < input_data_size; ++i) {
|
||||||
|
input_data.flat<float>()(i) = i + 1.0f;
|
||||||
|
}
|
||||||
|
Output input =
|
||||||
|
Const(root.WithOpName("input"), Input::Initializer(input_data));
|
||||||
|
|
||||||
|
const size_t filter_data_size =
|
||||||
|
filter_size * filter_size * filter_count * input_depth;
|
||||||
|
Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
|
||||||
|
input_depth, filter_count}));
|
||||||
|
for (int i = 0; i < filter_data_size; ++i) {
|
||||||
|
filter_data.flat<float>()(i) = i + 1.0f;
|
||||||
|
}
|
||||||
|
Output filter =
|
||||||
|
Const(root.WithOpName("filter"), Input::Initializer(filter_data));
|
||||||
|
|
||||||
|
Output resize_size =
|
||||||
|
Const(root.WithOpName("resize_size"), {resize_height, resize_width});
|
||||||
|
Output resize =
|
||||||
|
ResizeBilinear(root.WithOpName("resize"), input, resize_size,
|
||||||
|
ResizeBilinear::AlignCorners(resize_align_corners));
|
||||||
|
Output paddings =
|
||||||
|
Const(root.WithOpName("paddings"),
|
||||||
|
{{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
|
||||||
|
Output mirror_pad =
|
||||||
|
MirrorPad(root.WithOpName("mirror_pad"), resize, paddings, pad_mode);
|
||||||
|
Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter,
|
||||||
|
{1, stride, stride, 1}, padding);
|
||||||
|
|
||||||
|
Output fused_conv = FusedResizeAndPadConv2D(
|
||||||
|
root.WithOpName("fused_conv"), input, resize_size, paddings, filter,
|
||||||
|
pad_mode, {1, stride, stride, 1}, padding,
|
||||||
|
FusedResizeAndPadConv2D::ResizeAlignCorners(resize_align_corners));
|
||||||
|
|
||||||
|
tensorflow::GraphDef graph;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph));
|
||||||
|
|
||||||
|
std::unique_ptr<tensorflow::Session> session(
|
||||||
|
tensorflow::NewSession(tensorflow::SessionOptions()));
|
||||||
|
TF_ASSERT_OK(session->Create(graph));
|
||||||
|
|
||||||
|
std::vector<Tensor> unfused_tensors;
|
||||||
|
TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors));
|
||||||
|
|
||||||
|
std::vector<Tensor> fused_tensors;
|
||||||
|
TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check against hand-computed convolution results.
TEST_F(FusedResizePadConvOpTest, HandwrittenConv) {
  HandwrittenConv();
}

// The remaining cases compare the fused op with the separate resize, pad and
// convolution ops for a range of sizes, modes and strides.

TEST_F(FusedResizePadConvOpTest, IdentityComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/10, /*input_height=*/10, /*input_depth=*/1,
      /*resize_width=*/10, /*resize_height=*/10,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, ConvOnlyComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/10, /*input_height=*/10, /*input_depth=*/3,
      /*resize_width=*/10, /*resize_height=*/10,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/4, /*filter_count=*/4,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, ResizeOnlyComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/10, /*input_height=*/10, /*input_depth=*/1,
      /*resize_width=*/20, /*resize_height=*/20,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, ResizeAndConvComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/2, /*input_height=*/2, /*input_depth=*/4,
      /*resize_width=*/4, /*resize_height=*/2,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/2, /*filter_count=*/2,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/2, /*input_height=*/2, /*input_depth=*/4,
      /*resize_width=*/4, /*resize_height=*/2,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/2, /*filter_count=*/2,
      /*resize_align_corners=*/true, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, ResizeAndConvStridedComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/2, /*input_height=*/2, /*input_depth=*/4,
      /*resize_width=*/4, /*resize_height=*/2,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/2, /*filter_count=*/2,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/2, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvValidComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/2, /*input_height=*/2, /*input_depth=*/4,
      /*resize_width=*/4, /*resize_height=*/2,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/2, /*filter_count=*/2,
      /*resize_align_corners=*/true, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"VALID");
}

TEST_F(FusedResizePadConvOpTest, PadOnlyComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/1,
      /*resize_width=*/4, /*resize_height=*/4,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, PadOnlyWithChannelsComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/3,
      /*resize_width=*/4, /*resize_height=*/4,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, ResizeAndPadComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/1,
      /*resize_width=*/6, /*resize_height=*/6,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, PadOnlySymmetricComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/1,
      /*resize_width=*/4, /*resize_height=*/4,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"SYMMETRIC",
      /*stride=*/1, /*padding=*/"SAME");
}

TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/3,
      /*resize_width=*/6, /*resize_height=*/6,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"SYMMETRIC",
      /*stride=*/1, /*padding=*/"SAME");
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -56,14 +56,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_slice.h"
|
#include "tensorflow/core/framework/tensor_slice.h"
|
||||||
#include "tensorflow/core/kernels/bounds_check.h"
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
|
#include "tensorflow/core/kernels/conv_ops.h"
|
||||||
|
#include "tensorflow/core/kernels/gemm_functors.h"
|
||||||
|
#include "tensorflow/core/kernels/image_resizer_state.h"
|
||||||
|
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||||
#include "tensorflow/core/util/padding.h"
|
#include "tensorflow/core/util/padding.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
#if defined(__APPLE__)
|
|
||||||
#include <Accelerate/Accelerate.h>
|
|
||||||
#define USE_ACCELERATE_GEMM
|
|
||||||
#endif // __APPLE__
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -189,87 +188,6 @@ class ReferenceConvFunctor {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Straightforward, unoptimized matrix multiplication intended for debugging
// and for understanding the algorithm. Substitute it for FastGemmFunctor in the
// Im2ColConvFunctor template arguments of the op registration to use it.
// All three matrices are expected to be stored in row-major order.
template <class T1, class T2, class T3>
class ReferenceGemmFunctor {
 public:
  // Computes c = a * b, where a is m x k, b is k x n and c is m x n. lda, ldb
  // and ldc are the distances, in elements, between the starts of consecutive
  // rows of a, b and c. Every element of c in the m x n window is overwritten.
  void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
                  const T2* b, size_t ldb, T3* c, size_t ldc) {
    for (size_t col = 0; col < n; ++col) {
      for (size_t row = 0; row < m; ++row) {
        // Dot product of row `row` of a with column `col` of b.
        T3 sum(0);
        for (size_t depth = 0; depth < k; ++depth) {
          const T1 lhs = a[(row * lda) + depth];
          const T2 rhs = b[(depth * ldb) + col];
          sum += (lhs * rhs);
        }
        c[(row * ldc) + col] = sum;
      }
    }
  }
};
|
|
||||||
|
|
||||||
// Uses the optimized Eigen library to implement the matrix multiplication
|
|
||||||
// required by the Im2ColConvFunctor class. We supply the two input and one
|
|
||||||
// output types so that the accumulator can potentially be higher-precision than
|
|
||||||
// the inputs, even though we don't currently take advantage of this.
|
|
||||||
template <class T1, class T2, class T3>
|
|
||||||
class FastGemmFunctor {
|
|
||||||
public:
|
|
||||||
// Convenience wrappers for the Eigen matrix types we'll be using.
|
|
||||||
typedef Eigen::Map<
|
|
||||||
const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
|
||||||
ConstMatrixT1;
|
|
||||||
typedef Eigen::Map<
|
|
||||||
const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
|
||||||
ConstMatrixT2;
|
|
||||||
typedef Eigen::Map<
|
|
||||||
Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
|
||||||
MatrixT3;
|
|
||||||
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
|
||||||
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
|
||||||
ConstMatrixT1 a_matrix(a, m, k);
|
|
||||||
ConstMatrixT2 b_matrix(b, k, n);
|
|
||||||
MatrixT3 c_matrix(c, m, n);
|
|
||||||
c_matrix.noalias() = a_matrix * b_matrix;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// If we have Apple's Accelerate framework, use their implementation of GEMM to
|
|
||||||
// get a performance boost for float.
|
|
||||||
#if defined(USE_ACCELERATE_GEMM)
|
|
||||||
template <>
|
|
||||||
class FastGemmFunctor<float, float, float> {
|
|
||||||
public:
|
|
||||||
void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
|
|
||||||
const float* b, size_t ldb, float* c, size_t ldc) {
|
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
|
|
||||||
lda, b, ldb, 0.0f, c, ldc);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
#endif // USE_ACCELERATE_GEMM
|
|
||||||
|
|
||||||
// Used to keep track of persistent memory buffers used within the op.
|
|
||||||
template <class T, size_t size>
|
|
||||||
struct Im2ColBufferResource : public ResourceBase {
|
|
||||||
mutex mu;
|
|
||||||
T data[size];
|
|
||||||
string DebugString() { return "Im2ColBufferResource"; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// Implements convolution as a two stage process, first packing the patches of
|
// Implements convolution as a two stage process, first packing the patches of
|
||||||
// the input image into columns (im2col) and then running GEMM to produce the
|
// the input image into columns (im2col) and then running GEMM to produce the
|
||||||
// final result.
|
// final result.
|
||||||
@ -344,7 +262,6 @@ class Im2ColConvFunctor {
|
|||||||
errors::InvalidArgument("Im2Col patch too large for buffer"));
|
errors::InvalidArgument("Im2Col patch too large for buffer"));
|
||||||
const size_t patches_per_chunk =
|
const size_t patches_per_chunk =
|
||||||
max_chunk_size / (filter_value_count * sizeof(T1));
|
max_chunk_size / (filter_value_count * sizeof(T1));
|
||||||
|
|
||||||
// Because memory allocation is very expensive on mobile platforms, try to
|
// Because memory allocation is very expensive on mobile platforms, try to
|
||||||
// allocate a persistent buffer that will be kept around between calls. We
|
// allocate a persistent buffer that will be kept around between calls. We
|
||||||
// use TensorFlow's resource management to ensure that the memory will be
|
// use TensorFlow's resource management to ensure that the memory will be
|
||||||
|
@ -99,13 +99,15 @@ struct scalar_sqrt_gradient_op {
|
|||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op)
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
|
||||||
operator()(const T& output, const T& output_gradient) const {
|
operator()(const T& output, const T& output_gradient) const {
|
||||||
return static_cast<T>(0.5) * output_gradient / output;
|
const T out_conj = numext::conj(output);
|
||||||
|
return static_cast<T>(0.5) * output_gradient / out_conj;
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
|
||||||
packetOp(const Packet& output, const Packet& output_gradient) const {
|
packetOp(const Packet& output, const Packet& output_gradient) const {
|
||||||
const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
|
const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
|
||||||
return pdiv(pmul(const_half, output_gradient), output);
|
const Packet out_conj = pconj(output);
|
||||||
|
return pdiv(pmul(const_half, output_gradient), out_conj);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -123,15 +125,17 @@ struct scalar_rsqrt_gradient_op {
|
|||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op)
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
|
||||||
operator()(const T& output, const T& output_gradient) const {
|
operator()(const T& output, const T& output_gradient) const {
|
||||||
return static_cast<T>(-0.5) * (output_gradient * output) *
|
const T out_conj = numext::conj(output);
|
||||||
(output * output);
|
return static_cast<T>(-0.5) * (output_gradient * out_conj) *
|
||||||
|
(out_conj * out_conj);
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
|
||||||
packetOp(const Packet& output, const Packet& output_gradient) const {
|
packetOp(const Packet& output, const Packet& output_gradient) const {
|
||||||
const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
|
const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
|
||||||
return pmul(const_half,
|
const Packet out_conj = pconj(output);
|
||||||
pmul(pmul(output_gradient, output), pmul(output, output)));
|
return pmul(const_half, pmul(pmul(output_gradient, out_conj),
|
||||||
|
pmul(out_conj, out_conj)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -35,21 +35,49 @@ class DrawBoundingBoxesOp : public OpKernel {
|
|||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
const Tensor& images = context->input(0);
|
const Tensor& images = context->input(0);
|
||||||
const Tensor& boxes = context->input(1);
|
const Tensor& boxes = context->input(1);
|
||||||
|
const int64 depth = images.dim_size(3);
|
||||||
|
|
||||||
OP_REQUIRES(context, images.dims() == 4,
|
OP_REQUIRES(context, images.dims() == 4,
|
||||||
errors::InvalidArgument("The rank of the images should be 4"));
|
errors::InvalidArgument("The rank of the images should be 4"));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, boxes.dims() == 3,
|
context, boxes.dims() == 3,
|
||||||
errors::InvalidArgument("The rank of the boxes tensor should be 3"));
|
errors::InvalidArgument("The rank of the boxes tensor should be 3"));
|
||||||
|
|
||||||
OP_REQUIRES(context, images.dim_size(0) == boxes.dim_size(0),
|
OP_REQUIRES(context, images.dim_size(0) == boxes.dim_size(0),
|
||||||
errors::InvalidArgument("The batch sizes should be the same"));
|
errors::InvalidArgument("The batch sizes should be the same"));
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, depth == 4 || depth == 1 || depth == 3,
|
||||||
|
errors::InvalidArgument("Channel depth should be either 1 (GRY), "
|
||||||
|
"3 (RGB), or 4 (RGBA)"));
|
||||||
|
|
||||||
const int64 batch_size = images.dim_size(0);
|
const int64 batch_size = images.dim_size(0);
|
||||||
const int64 height = images.dim_size(1);
|
const int64 height = images.dim_size(1);
|
||||||
const int64 width = images.dim_size(2);
|
const int64 width = images.dim_size(2);
|
||||||
const int64 depth = images.dim_size(3);
|
const int64 color_table_length = 10;
|
||||||
|
|
||||||
|
// 0: yellow
|
||||||
|
// 1: blue
|
||||||
|
// 2: red
|
||||||
|
// 3: lime
|
||||||
|
// 4: purple
|
||||||
|
// 5: olive
|
||||||
|
// 6: maroon
|
||||||
|
// 7: navy blue
|
||||||
|
// 8: aqua
|
||||||
|
// 9: fuchsia
|
||||||
|
float color_table[color_table_length][4] = {
|
||||||
|
{1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1},
|
||||||
|
{0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1},
|
||||||
|
{0, 1, 1, 1}, {1, 0, 1, 1},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reset first color channel to 1 if image is GRY.
|
||||||
|
// For GRY images, this means all bounding boxes will be white.
|
||||||
|
if (depth == 1) {
|
||||||
|
for (int64 i = 0; i < color_table_length; i++) {
|
||||||
|
color_table[i][0] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
Tensor* output;
|
Tensor* output;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
context,
|
context,
|
||||||
@ -62,8 +90,8 @@ class DrawBoundingBoxesOp : public OpKernel {
|
|||||||
for (int64 b = 0; b < batch_size; ++b) {
|
for (int64 b = 0; b < batch_size; ++b) {
|
||||||
const int64 num_boxes = boxes.dim_size(1);
|
const int64 num_boxes = boxes.dim_size(1);
|
||||||
const auto tboxes = boxes.tensor<T, 3>();
|
const auto tboxes = boxes.tensor<T, 3>();
|
||||||
|
|
||||||
for (int64 bb = 0; bb < num_boxes; ++bb) {
|
for (int64 bb = 0; bb < num_boxes; ++bb) {
|
||||||
|
int64 color_index = bb % color_table_length;
|
||||||
const int64 min_box_row =
|
const int64 min_box_row =
|
||||||
static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
|
static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
|
||||||
const int64 min_box_row_clamp =
|
const int64 min_box_row_clamp =
|
||||||
@ -122,22 +150,34 @@ class DrawBoundingBoxesOp : public OpKernel {
|
|||||||
// Draw top line.
|
// Draw top line.
|
||||||
if (min_box_row >= 0) {
|
if (min_box_row >= 0) {
|
||||||
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
|
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
|
||||||
canvas(b, min_box_row, j, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
for (int64 c = 0; c < depth; c++) {
|
||||||
|
canvas(b, min_box_row, j, c) =
|
||||||
|
static_cast<T>(color_table[color_index][c]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Draw bottom line.
|
// Draw bottom line.
|
||||||
if (max_box_row < height) {
|
if (max_box_row < height) {
|
||||||
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
|
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
|
||||||
canvas(b, max_box_row, j, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
for (int64 c = 0; c < depth; c++) {
|
||||||
|
canvas(b, max_box_row, j, c) =
|
||||||
|
static_cast<T>(color_table[color_index][c]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Draw left line.
|
// Draw left line.
|
||||||
if (min_box_col >= 0) {
|
if (min_box_col >= 0) {
|
||||||
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
|
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
|
||||||
canvas(b, i, min_box_col, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
for (int64 c = 0; c < depth; c++) {
|
||||||
|
canvas(b, i, min_box_col, c) =
|
||||||
|
static_cast<T>(color_table[color_index][c]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Draw right line.
|
// Draw right line.
|
||||||
if (max_box_col < width) {
|
if (max_box_col < width) {
|
||||||
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
|
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
|
||||||
canvas(b, i, max_box_col, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
for (int64 c = 0; c < depth; c++) {
|
||||||
|
canvas(b, i, max_box_col, c) =
|
||||||
|
static_cast<T>(color_table[color_index][c]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ class GatherNdOp : public OpKernel {
|
|||||||
"index innermost dimension length must be <= params rank; saw: ",
|
"index innermost dimension length must be <= params rank; saw: ",
|
||||||
indices.dim_size(indices.dims() - 1), " vs. ", params.dims()));
|
indices.dim_size(indices.dims() - 1), " vs. ", params.dims()));
|
||||||
|
|
||||||
TensorShape indices_shape(indices.shape());
|
const TensorShape& indices_shape(indices.shape());
|
||||||
const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);
|
const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);
|
||||||
|
|
||||||
// Check that we have enough index space
|
// Check that we have enough index space
|
||||||
@ -79,7 +79,7 @@ class GatherNdOp : public OpKernel {
|
|||||||
N_result *= indices_shape.dim_size(i);
|
N_result *= indices_shape.dim_size(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorShape params_shape(params.shape());
|
const TensorShape& params_shape(params.shape());
|
||||||
Index total_nd = params_shape.dims();
|
Index total_nd = params_shape.dims();
|
||||||
|
|
||||||
TensorShape result_shape(indices_shape);
|
TensorShape result_shape(indices_shape);
|
||||||
|
105
tensorflow/core/kernels/gemm_functors.h
Normal file
105
tensorflow/core/kernels/gemm_functors.h
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This is a set of different implementations for the basic matrix by matrix
|
||||||
|
// multiply function, commonly known as GEMM after the BLAS library's naming.
|
||||||
|
// Having a standard interface enables us to swap out implementations on
|
||||||
|
// different platforms, to make sure we're using the optimal version. They are
|
||||||
|
// implemented as C++ template functors, so they're easy to swap into all of the
|
||||||
|
// different kernels that use them.
|
||||||
|
|
||||||
|
#include <string.h>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
|
#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
|
||||||
|
#include <Accelerate/Accelerate.h>
|
||||||
|
#define USE_ACCELERATE_GEMM
|
||||||
|
#endif // __APPLE__
|
||||||
|
|
||||||
|
// A readable but slow implementation of matrix multiplication, useful for
|
||||||
|
// debugging and understanding the algorithm. Use instead of FastGemmFunctor in
|
||||||
|
// the Im2ColConvFunctor template definition inside the op registration to
|
||||||
|
// enable. Assumes row-major ordering of the values in memory.
|
||||||
|
template <class T1, class T2, class T3>
|
||||||
|
class ReferenceGemmFunctor {
|
||||||
|
public:
|
||||||
|
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
||||||
|
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
||||||
|
const size_t a_i_stride = lda;
|
||||||
|
const size_t a_l_stride = 1;
|
||||||
|
const size_t b_j_stride = 1;
|
||||||
|
const size_t b_l_stride = ldb;
|
||||||
|
const size_t c_i_stride = ldc;
|
||||||
|
const size_t c_j_stride = 1;
|
||||||
|
size_t i, j, l;
|
||||||
|
for (j = 0; j < n; j++) {
|
||||||
|
for (i = 0; i < m; i++) {
|
||||||
|
T3 total(0);
|
||||||
|
for (l = 0; l < k; l++) {
|
||||||
|
const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
|
||||||
|
const T1 a_value = a[a_index];
|
||||||
|
const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
|
||||||
|
const T2 b_value = b[b_index];
|
||||||
|
total += (a_value * b_value);
|
||||||
|
}
|
||||||
|
const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
|
||||||
|
c[c_index] = total;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Uses the optimized Eigen library to implement the matrix multiplication
|
||||||
|
// required by the Im2ColConvFunctor class. We supply the two input and one
|
||||||
|
// output types so that the accumulator can potentially be higher-precision than
|
||||||
|
// the inputs, even though we don't currently take advantage of this.
|
||||||
|
template <class T1, class T2, class T3>
|
||||||
|
class FastGemmFunctor {
|
||||||
|
public:
|
||||||
|
// Convenience wrappers for the Eigen matrix types we'll be using.
|
||||||
|
typedef Eigen::Map<
|
||||||
|
const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||||
|
ConstMatrixT1;
|
||||||
|
typedef Eigen::Map<
|
||||||
|
const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||||
|
ConstMatrixT2;
|
||||||
|
typedef Eigen::Map<
|
||||||
|
Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||||
|
MatrixT3;
|
||||||
|
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
||||||
|
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
||||||
|
ConstMatrixT1 a_matrix(a, m, k);
|
||||||
|
ConstMatrixT2 b_matrix(b, k, n);
|
||||||
|
MatrixT3 c_matrix(c, m, n);
|
||||||
|
c_matrix.noalias() = a_matrix * b_matrix;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// If we have Apple's Accelerate framework, use their implementation of GEMM to
|
||||||
|
// get a performance boost for float.
|
||||||
|
#if defined(USE_ACCELERATE_GEMM)
|
||||||
|
template <>
|
||||||
|
class FastGemmFunctor<float, float, float> {
|
||||||
|
public:
|
||||||
|
void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
|
||||||
|
const float* b, size_t ldb, float* c, size_t ldc) {
|
||||||
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
|
||||||
|
lda, b, ldb, 0.0f, c, ldc);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif // USE_ACCELERATE_GEMM
|
@ -49,12 +49,13 @@ struct ImageResizerState {
|
|||||||
explicit ImageResizerState(bool align_corners)
|
explicit ImageResizerState(bool align_corners)
|
||||||
: align_corners_(align_corners) {}
|
: align_corners_(align_corners) {}
|
||||||
|
|
||||||
// ValidateAndCreateOutput checks the bounds on the input tensors
|
// ValidateAndCalculateOutputSize checks the bounds on the input tensors
|
||||||
// and requested size, sets up some of the resizing state such as the
|
// and requested size, sets up some of the resizing state such as the
|
||||||
// height_scale and width_scale, and allocates the output.
|
// height_scale and width_scale, and calculates the output size.
|
||||||
// If any of these operations fails, it sets an error status in
|
// If any of these operations fails, it sets an error status in
|
||||||
// the context, which the caller must check.
|
// the context, which the caller must check.
|
||||||
void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
|
void ValidateAndCalculateOutputSize(OpKernelContext* context,
|
||||||
|
const Tensor& input) {
|
||||||
OP_REQUIRES(context, input.dims() == 4,
|
OP_REQUIRES(context, input.dims() == 4,
|
||||||
errors::InvalidArgument("input must be 4-dimensional",
|
errors::InvalidArgument("input must be 4-dimensional",
|
||||||
input.shape().DebugString()));
|
input.shape().DebugString()));
|
||||||
@ -87,12 +88,18 @@ struct ImageResizerState {
|
|||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, input.dim_size(1) > 0 && input.dim_size(2) > 0,
|
context, input.dim_size(1) > 0 && input.dim_size(2) > 0,
|
||||||
errors::InvalidArgument("input image must be of non-zero size"));
|
errors::InvalidArgument("input image must be of non-zero size"));
|
||||||
|
height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
|
||||||
|
width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculates all the required variables, and allocates the output.
|
||||||
|
void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
|
||||||
|
ValidateAndCalculateOutputSize(context, input);
|
||||||
|
if (!context->status().ok()) return;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(
|
OP_REQUIRES_OK(context, context->allocate_output(
|
||||||
0, TensorShape({input.dim_size(0), out_height,
|
0, TensorShape({input.dim_size(0), out_height,
|
||||||
out_width, input.dim_size(3)}),
|
out_width, input.dim_size(3)}),
|
||||||
&output));
|
&output));
|
||||||
height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
|
|
||||||
width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 batch_size;
|
int64 batch_size;
|
||||||
|
@ -272,7 +272,7 @@ class MaxPoolingGradOp : public OpKernel {
|
|||||||
OP_REQUIRES(context, out_backprop.dims() == 4,
|
OP_REQUIRES(context, out_backprop.dims() == 4,
|
||||||
errors::InvalidArgument("out_backprop must be 4-dimensional"));
|
errors::InvalidArgument("out_backprop must be 4-dimensional"));
|
||||||
|
|
||||||
TensorShape output_shape = tensor_in.shape();
|
const TensorShape& output_shape = tensor_in.shape();
|
||||||
|
|
||||||
Tensor tensor_out_dup;
|
Tensor tensor_out_dup;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
|
@ -185,6 +185,7 @@ class OpsTestBase : public ::testing::Test {
|
|||||||
test::SetOutputAttrs(params_.get(), &attrs);
|
test::SetOutputAttrs(params_.get(), &attrs);
|
||||||
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
|
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
|
||||||
params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
|
params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
|
||||||
|
params_.get()->resource_manager = device_.get()->resource_manager();
|
||||||
|
|
||||||
context_.reset(new OpKernelContext(params_.get()));
|
context_.reset(new OpKernelContext(params_.get()));
|
||||||
device_->Compute(kernel_.get(), context_.get());
|
device_->Compute(kernel_.get(), context_.get());
|
||||||
|
@ -34,6 +34,16 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/guarded_philox_random.h"
|
#include "tensorflow/core/util/guarded_philox_random.h"
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
|
||||||
|
#define DISABLE_FLOAT_EQUALITY_WARNING \
|
||||||
|
_Pragma("GCC diagnostic push") \
|
||||||
|
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
|
||||||
|
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
|
||||||
|
#else
|
||||||
|
#define DISABLE_FLOAT_EQUALITY_WARNING
|
||||||
|
#define ENABLE_FLOAT_EQUALITY_WARNING
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
@ -355,47 +365,23 @@ class RandomGammaOp : public OpKernel {
|
|||||||
// Several calculations can be done on a per-alpha basis.
|
// Several calculations can be done on a per-alpha basis.
|
||||||
const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
|
const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
|
||||||
|
|
||||||
if (alpha < 0.3) {
|
DISABLE_FLOAT_EQUALITY_WARNING
|
||||||
// For very small alpha, we use the log-space algorithm proposed in
|
if (alpha == double(1.0)) {
|
||||||
// "Simulating from a gamma distribution with small shape parameter",
|
ENABLE_FLOAT_EQUALITY_WARNING
|
||||||
// http://arxiv.org/abs/1302.1884
|
// Sample from an exponential distribution.
|
||||||
const double lambda = 1 / alpha - 1;
|
|
||||||
const double w = alpha / (M_E /* exp(1) */ * (1 - alpha));
|
|
||||||
const double r = 1 / (1 + w);
|
|
||||||
|
|
||||||
// Compute the rest of the samples for the current alpha value.
|
|
||||||
for (int64 sample_idx = output_idx % num_samples;
|
for (int64 sample_idx = output_idx % num_samples;
|
||||||
sample_idx < num_samples && output_idx < limit_output;
|
sample_idx < num_samples && output_idx < limit_output;
|
||||||
sample_idx++, output_idx++) {
|
sample_idx++, output_idx++) {
|
||||||
// Since each sample may use a variable number of normal/uniform
|
// As we want data stable regardless of sharding
|
||||||
// samples, and we want data stable regardless of sharding
|
|
||||||
// (including eventually on GPU), we skip on a per-sample basis.
|
// (including eventually on GPU), we skip on a per-sample basis.
|
||||||
PhiloxRandom gen = rng;
|
PhiloxRandom gen = rng;
|
||||||
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
||||||
short uniform_remaining = 0;
|
short uniform_remaining = 0;
|
||||||
|
|
||||||
// Keep trying until we don't reject a sample. In practice, we
|
|
||||||
// expect a low rejection rate.
|
|
||||||
while (true) {
|
|
||||||
UNIFORM(u);
|
UNIFORM(u);
|
||||||
double z;
|
const double res = -log(1.0 - u);
|
||||||
if (u <= r) {
|
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
|
||||||
z = -log(u / r);
|
} // for (sample_idx)
|
||||||
} else {
|
} else { // if alpha != 1.0
|
||||||
UNIFORM(v);
|
|
||||||
z = log(v) / lambda;
|
|
||||||
}
|
|
||||||
double eta = z >= 0 ? exp(-z) : w * lambda * exp(lambda * z);
|
|
||||||
UNIFORM(v);
|
|
||||||
double h = exp(-z - exp(-z / alpha));
|
|
||||||
if (h > eta * v) {
|
|
||||||
samples_alpha_offset[sample_idx * num_alphas] =
|
|
||||||
static_cast<T>(exp(-z / alpha));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} // while: true
|
|
||||||
} // for: sample_idx
|
|
||||||
} else { // so, alpha >= 0.3
|
|
||||||
// Transformation-rejection from pairs of uniform and normal random
|
// Transformation-rejection from pairs of uniform and normal random
|
||||||
// variables. http://dl.acm.org/citation.cfm?id=358414
|
// variables. http://dl.acm.org/citation.cfm?id=358414
|
||||||
//
|
//
|
||||||
@ -454,7 +440,7 @@ class RandomGammaOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
} // while: true
|
} // while: true
|
||||||
} // for: sample_idx
|
} // for: sample_idx
|
||||||
} // if: alpha < 0.3
|
} // if (alpha == 1.0)
|
||||||
} // for: output_idx
|
} // for: output_idx
|
||||||
}; // DoWork
|
}; // DoWork
|
||||||
#undef UNIFORM
|
#undef UNIFORM
|
||||||
@ -463,9 +449,7 @@ class RandomGammaOp : public OpKernel {
|
|||||||
// Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
|
// Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
|
||||||
// each = ~60.
|
// each = ~60.
|
||||||
// All of this /0.95 due to the rejection possibility = ~85.
|
// All of this /0.95 due to the rejection possibility = ~85.
|
||||||
// All of this * ~2 to incorporate possibility of the log/exp branch for
|
static const int kElementCost = 85 + 2 * Normal::kElementCost +
|
||||||
// low-alpha. (1 log, 4 exp, 3/, 3*)
|
|
||||||
static const int kElementCost = 170 + 2 * Normal::kElementCost +
|
|
||||||
Uniform::kElementCost +
|
Uniform::kElementCost +
|
||||||
3 * PhiloxRandom::kElementCost;
|
3 * PhiloxRandom::kElementCost;
|
||||||
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||||
|
@ -65,7 +65,7 @@ TEST_F(RestoreOpTest, RestoreSimple) {
|
|||||||
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
|
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
|
||||||
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
|
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
|
||||||
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
|
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
|
||||||
"tensor_half"};
|
"tensor_half", "tensor_float_empty"};
|
||||||
|
|
||||||
// We first need to write a tensor using the save_op
|
// We first need to write a tensor using the save_op
|
||||||
{
|
{
|
||||||
@ -164,6 +164,11 @@ TEST_F(RestoreOpTest, RestoreSimple) {
|
|||||||
return static_cast<Eigen::half>(x) / Eigen::half(5);
|
return static_cast<Eigen::half>(x) / Eigen::half(5);
|
||||||
});
|
});
|
||||||
inputs.push_back({nullptr, &input_14});
|
inputs.push_back({nullptr, &input_14});
|
||||||
|
// Input #15 is a 2-d empty float tensor
|
||||||
|
Tensor input_15 = MakeInput<float>(TensorShape({2, 0}), [](int x) -> float {
|
||||||
|
return static_cast<float>(x) / 10;
|
||||||
|
});
|
||||||
|
inputs.push_back({nullptr, &input_15});
|
||||||
OpKernelContext::Params params;
|
OpKernelContext::Params params;
|
||||||
params.device = device.get();
|
params.device = device.get();
|
||||||
params.frame_iter = FrameAndIter(0, 0);
|
params.frame_iter = FrameAndIter(0, 0);
|
||||||
@ -341,6 +346,15 @@ TEST_F(RestoreOpTest, RestoreSimple) {
|
|||||||
output->flat<Eigen::half>()(i));
|
output->flat<Eigen::half>()(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// The 2-d empty float tensor
|
||||||
|
{
|
||||||
|
MakeRestoreOp(DT_FLOAT);
|
||||||
|
(*mutable_input(1).tensor).scalar<string>()() = tensor_names[13];
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor* output = GetOutput(0);
|
||||||
|
TensorShape expected({2, 0});
|
||||||
|
EXPECT_TRUE(output->shape().IsSameSize(expected));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class RestoreSliceOpTest : public OpsTestBase {
|
class RestoreSliceOpTest : public OpsTestBase {
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user