Merge pull request #4220 from caisq/branch_132225803
Merge Changes from Internal: Branch 132225803
This commit is contained in:
commit
7a45bc5e7f
BUILD
tensorflow
BUILD
c
cc/tutorials
contrib
bayesflow
distributions
factorization/python/ops
grid_rnn
layers
learn
BUILD
python/learn
linear_optimizer/python/ops
losses/python/losses
makefile
metrics
quantization/kernels/hexagon
rnn
training
core
BUILD
common_runtime
bfc_allocator.ccconstant_folding.cccopy_tensor.ccdevice_set.ccdirect_session.ccdirect_session.hdirect_session_test.ccfunction_test.cc
gpu
simple_graph_execution_state.ccsimple_graph_execution_state.hsimple_placer.ccsimple_placer_test.ccdistributed_runtime
framework
cost_graph.protofunction.ccfunction.hfunction_testlib.ccop_def_util.ccop_kernel_test.cctensor_shape.htensor_slice.htensor_slice_test.cc
graph
optimizer_cse_test.ccquantize_training.ccshape_refiner.ccshape_refiner.hshape_refiner_test.ccsubgraph.cc
kernels
@ -29,6 +29,15 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "android_arm64",
|
||||
values = {
|
||||
"crosstool_top": "//external:android/crosstool",
|
||||
"android_cpu": "arm64-v8a",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "darwin",
|
||||
values = {"cpu": "darwin"},
|
||||
@ -95,6 +104,7 @@ filegroup(
|
||||
"//tensorflow/contrib/ffmpeg/default:all_files",
|
||||
"//tensorflow/contrib/framework:all_files",
|
||||
"//tensorflow/contrib/graph_editor:all_files",
|
||||
"//tensorflow/contrib/grid_rnn:all_files",
|
||||
"//tensorflow/contrib/layers:all_files",
|
||||
"//tensorflow/contrib/layers/kernels:all_files",
|
||||
"//tensorflow/contrib/learn:all_files",
|
||||
|
@ -87,7 +87,7 @@ TEST(CApi, AllocateTensor) {
|
||||
static void TestEncodeDecode(int line,
|
||||
const std::vector<tensorflow::string>& data) {
|
||||
const tensorflow::int64 n = data.size();
|
||||
for (std::vector<tensorflow::int64> dims :
|
||||
for (const std::vector<tensorflow::int64>& dims :
|
||||
std::vector<std::vector<tensorflow::int64>>{
|
||||
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
|
||||
// Create C++ Tensor
|
||||
|
@ -37,7 +37,7 @@ namespace tensorflow {
|
||||
namespace example {
|
||||
|
||||
struct Options {
|
||||
int num_concurrent_sessions = 10; // The number of concurrent sessions
|
||||
int num_concurrent_sessions = 1; // The number of concurrent sessions
|
||||
int num_concurrent_steps = 10; // The number of concurrent steps
|
||||
int num_iterations = 100; // Each step repeats this many times
|
||||
bool use_gpu = false; // Whether to use gpu in the training
|
||||
@ -108,10 +108,11 @@ void ConcurrentSteps(const Options* opts, int session_index) {
|
||||
|
||||
// Spawn M threads for M concurrent steps.
|
||||
const int M = opts->num_concurrent_steps;
|
||||
thread::ThreadPool step_threads(Env::Default(), "trainer", M);
|
||||
std::unique_ptr<thread::ThreadPool> step_threads(
|
||||
new thread::ThreadPool(Env::Default(), "trainer", M));
|
||||
|
||||
for (int step = 0; step < M; ++step) {
|
||||
step_threads.Schedule([&session, opts, session_index, step]() {
|
||||
step_threads->Schedule([&session, opts, session_index, step]() {
|
||||
// Randomly initialize the input.
|
||||
Tensor x(DT_FLOAT, TensorShape({2, 1}));
|
||||
auto x_flat = x.flat<float>();
|
||||
@ -139,12 +140,19 @@ void ConcurrentSteps(const Options* opts, int session_index) {
|
||||
});
|
||||
}
|
||||
|
||||
// Delete the threadpool, thus waiting for all threads to complete.
|
||||
step_threads.reset(nullptr);
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
|
||||
void ConcurrentSessions(const Options& opts) {
|
||||
// Spawn N threads for N concurrent sessions.
|
||||
const int N = opts.num_concurrent_sessions;
|
||||
|
||||
// At the moment our Session implementation only allows
|
||||
// one concurrently computing Session on GPU.
|
||||
CHECK_EQ(1, N) << "Currently can only have one concurrent session.";
|
||||
|
||||
thread::ThreadPool session_threads(Env::Default(), "trainer", N);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i));
|
||||
|
@ -23,6 +23,7 @@ cuda_py_test(
|
||||
srcs = ["python/kernel_tests/entropy_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -34,6 +35,7 @@ cuda_py_test(
|
||||
srcs = ["python/kernel_tests/monte_carlo_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -45,6 +47,7 @@ cuda_py_test(
|
||||
srcs = ["python/kernel_tests/special_math_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -56,6 +59,7 @@ cuda_py_test(
|
||||
srcs = ["python/kernel_tests/stochastic_graph_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -67,6 +71,7 @@ cuda_py_test(
|
||||
srcs = ["python/kernel_tests/variational_inference_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -78,6 +83,7 @@ cuda_py_test(
|
||||
srcs = ["python/kernel_tests/stochastic_tensor_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -89,6 +95,7 @@ cuda_py_test(
|
||||
srcs = ["examples/reinforce_simple/reinforce_simple_example.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
|
@ -159,6 +159,21 @@ class NdtrGradientTest(tf.test.TestCase):
|
||||
_use_log = False
|
||||
_grid = GridSpec(min=-100., max=100., shape=[1, 2, 3, 8])
|
||||
|
||||
def assert_all_true(self, v):
|
||||
self.assertAllEqual(np.ones_like(v, dtype=np.bool), v)
|
||||
|
||||
def assert_all_false(self, v):
|
||||
self.assertAllEqual(np.zeros_like(v, dtype=np.bool), v)
|
||||
|
||||
def _test_grad_finite(self, dtype):
|
||||
with self.test_session():
|
||||
x = tf.Variable([-100., 0., 100.], dtype=dtype)
|
||||
output = (sm.log_ndtr(x) if self._use_log else sm.ndtr(x))
|
||||
grad_output = tf.gradients(output, x)
|
||||
tf.initialize_all_variables().run()
|
||||
self.assert_all_true(np.isfinite(output.eval()))
|
||||
self.assert_all_true(np.isfinite(grad_output[0].eval()))
|
||||
|
||||
def _test_grads_are_positive(self, dtype, grid_spec):
|
||||
grid = tf.convert_to_tensor(_make_grid(dtype, grid_spec))
|
||||
with self.test_session():
|
||||
@ -169,20 +184,24 @@ class NdtrGradientTest(tf.test.TestCase):
|
||||
# grad_eval.shape = (N, N), with grad_eval[i, j] the partial derivative of
|
||||
# the ith output point w.r.t. the jth grid point. We only expect the
|
||||
# diagonal to be nonzero.
|
||||
# TODO(b/31131137): Replace tf.test.compute_gradient with our own custom
|
||||
# gradient evaluation to ensure we correctly handle small function delta.
|
||||
grad_eval, _ = tf.test.compute_gradient(
|
||||
grid, grid_spec.shape, output, grid_spec.shape)
|
||||
grad_eval = np.diag(grad_eval)
|
||||
|
||||
# Check for NaN separately in order to get informative failures.
|
||||
self.assertFalse(np.isnan(grad_eval).any())
|
||||
self.assertTrue((grad_eval > 0).all())
|
||||
self.assertTrue(np.isfinite(grad_eval).all())
|
||||
self.assert_all_false(np.isnan(grad_eval))
|
||||
self.assert_all_true(grad_eval > 0.)
|
||||
self.assert_all_true(np.isfinite(grad_eval))
|
||||
|
||||
def test_float32(self):
|
||||
self._test_grads_are_positive(np.float32, self._grid)
|
||||
self._test_grad_finite(np.float32)
|
||||
|
||||
def test_float64(self):
|
||||
self._test_grads_are_positive(np.float64, self._grid)
|
||||
self._test_grad_finite(np.float64)
|
||||
|
||||
|
||||
class LogNdtrGradientTest(NdtrGradientTest):
|
||||
|
@ -174,13 +174,18 @@ def log_ndtr(x, series_order=3, name=None):
|
||||
# * We use one fixed series_order for all of 'x', rather than adaptive.
|
||||
# * Our docstring properly reflects that this is an asymptotic series, not a
|
||||
# Taylor series. We also provided a correct bound on the remainder.
|
||||
|
||||
# * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
|
||||
# x=0. This happens even though the branch is unchosen because when x=0
|
||||
# the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
|
||||
# regardless of whether dy is finite. Note that the minimum is a NOP if
|
||||
# the branch is chosen.
|
||||
return math_ops.select(
|
||||
math_ops.greater(x, upper_segment),
|
||||
-_ndtr(-x), # log(1-x) ~= -x, x << 1
|
||||
math_ops.select(math_ops.greater(x, lower_segment),
|
||||
math_ops.log(_ndtr(x)),
|
||||
_log_ndtr_lower(x, series_order)))
|
||||
math_ops.log(_ndtr(math_ops.maximum(x, lower_segment))),
|
||||
_log_ndtr_lower(math_ops.minimum(x, lower_segment),
|
||||
series_order)))
|
||||
|
||||
|
||||
def _log_ndtr_lower(x, series_order):
|
||||
|
@ -16,6 +16,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/operator_pd_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -27,6 +28,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/operator_pd_cholesky_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -38,6 +40,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/operator_pd_diag_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -49,6 +52,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/operator_pd_full_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -60,6 +64,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/operator_pd_identity_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -71,6 +76,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/operator_pd_vdvt_update_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -89,6 +95,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/bernoulli_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -99,6 +106,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/beta_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
tags = ["notsan"], #http://b/31216497
|
||||
@ -110,6 +118,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/binomial_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -120,6 +129,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/categorical_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -129,6 +139,7 @@ cuda_py_tests(
|
||||
name = "chi2_test",
|
||||
srcs = ["python/kernel_tests/chi2_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -140,6 +151,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/dirichlet_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -151,6 +163,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -161,6 +174,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/exponential_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -170,6 +184,7 @@ cuda_py_tests(
|
||||
name = "gamma_test",
|
||||
srcs = ["python/kernel_tests/gamma_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -180,6 +195,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/inverse_gamma_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -190,6 +206,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/laplace_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -200,6 +217,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/multinomial_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -211,6 +229,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/mvn_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -222,6 +241,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/mixture_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -233,6 +253,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/normal_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -244,6 +265,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/poisson_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -255,6 +277,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/student_t_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -266,6 +289,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/uniform_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
@ -277,6 +301,7 @@ cuda_py_tests(
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -286,6 +311,7 @@ cuda_py_tests(
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/kullback_leibler_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -296,6 +322,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/normal_conjugate_posteriors_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -306,6 +333,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/transformed_distribution_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -316,6 +344,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/distribution_util_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -327,6 +356,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/shape_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -338,6 +368,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/bijector_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
|
@ -27,6 +27,14 @@ import tensorflow as tf
|
||||
|
||||
class NormalTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._rng = np.random.RandomState(123)
|
||||
|
||||
def assertAllFinite(self, tensor):
|
||||
is_finite = np.isfinite(tensor.eval())
|
||||
all_true = np.ones_like(is_finite, dtype=np.bool)
|
||||
self.assertAllEqual(all_true, is_finite)
|
||||
|
||||
def _testParamShapes(self, sample_shape, expected):
|
||||
with self.test_session():
|
||||
param_shapes = tf.contrib.distributions.Normal.param_shapes(sample_shape)
|
||||
@ -143,21 +151,94 @@ class NormalTest(tf.test.TestCase):
|
||||
|
||||
def testNormalCDF(self):
|
||||
with self.test_session():
|
||||
batch_size = 6
|
||||
mu = tf.constant([3.0] * batch_size)
|
||||
sigma = tf.constant([math.sqrt(10.0)] * batch_size)
|
||||
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
|
||||
batch_size = 50
|
||||
mu = self._rng.randn(batch_size)
|
||||
sigma = self._rng.rand(batch_size) + 1.0
|
||||
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
|
||||
|
||||
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||
expected_cdf = stats.norm(mu.eval(), sigma.eval()).cdf(x)
|
||||
expected_cdf = stats.norm(mu, sigma).cdf(x)
|
||||
|
||||
cdf = normal.cdf(x)
|
||||
self.assertAllClose(expected_cdf, cdf.eval())
|
||||
self.assertAllClose(expected_cdf, cdf.eval(), atol=0)
|
||||
self.assertAllEqual(normal.batch_shape().eval(), cdf.get_shape())
|
||||
self.assertAllEqual(normal.batch_shape().eval(), cdf.eval().shape)
|
||||
self.assertAllEqual(normal.get_batch_shape(), cdf.get_shape())
|
||||
self.assertAllEqual(normal.get_batch_shape(), cdf.eval().shape)
|
||||
|
||||
def testNormalSurvivalFunction(self):
|
||||
with self.test_session():
|
||||
batch_size = 50
|
||||
mu = self._rng.randn(batch_size)
|
||||
sigma = self._rng.rand(batch_size) + 1.0
|
||||
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
|
||||
|
||||
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||
expected_sf = stats.norm(mu, sigma).sf(x)
|
||||
|
||||
sf = normal.survival_function(x)
|
||||
self.assertAllClose(expected_sf, sf.eval(), atol=0)
|
||||
self.assertAllEqual(normal.batch_shape().eval(), sf.get_shape())
|
||||
self.assertAllEqual(normal.batch_shape().eval(), sf.eval().shape)
|
||||
self.assertAllEqual(normal.get_batch_shape(), sf.get_shape())
|
||||
self.assertAllEqual(normal.get_batch_shape(), sf.eval().shape)
|
||||
|
||||
def testNormalLogCDF(self):
|
||||
with self.test_session():
|
||||
batch_size = 50
|
||||
mu = self._rng.randn(batch_size)
|
||||
sigma = self._rng.rand(batch_size) + 1.0
|
||||
x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
|
||||
|
||||
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||
expected_cdf = stats.norm(mu, sigma).logcdf(x)
|
||||
|
||||
cdf = normal.log_cdf(x)
|
||||
self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5)
|
||||
self.assertAllEqual(normal.batch_shape().eval(), cdf.get_shape())
|
||||
self.assertAllEqual(normal.batch_shape().eval(), cdf.eval().shape)
|
||||
self.assertAllEqual(normal.get_batch_shape(), cdf.get_shape())
|
||||
self.assertAllEqual(normal.get_batch_shape(), cdf.eval().shape)
|
||||
|
||||
def testFiniteGradientAtDifficultPoints(self):
|
||||
with self.test_session():
|
||||
for dtype in [np.float32, np.float64]:
|
||||
mu = tf.Variable(dtype(0.0))
|
||||
sigma = tf.Variable(dtype(1.0))
|
||||
dist = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||
tf.initialize_all_variables().run()
|
||||
for func in [
|
||||
dist.cdf,
|
||||
dist.log_cdf,
|
||||
dist.survival_function,
|
||||
dist.log_survival_function,
|
||||
dist.log_prob,
|
||||
dist.prob]:
|
||||
x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype)
|
||||
value = func(x)
|
||||
grads = tf.gradients(value, [mu, sigma])
|
||||
|
||||
self.assertAllFinite(value)
|
||||
self.assertAllFinite(grads[0])
|
||||
self.assertAllFinite(grads[1])
|
||||
|
||||
def testNormalLogSurvivalFunction(self):
|
||||
with self.test_session():
|
||||
batch_size = 50
|
||||
mu = self._rng.randn(batch_size)
|
||||
sigma = self._rng.rand(batch_size) + 1.0
|
||||
x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
|
||||
|
||||
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
|
||||
expected_sf = stats.norm(mu, sigma).logsf(x)
|
||||
|
||||
sf = normal.log_survival_function(x)
|
||||
self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5)
|
||||
self.assertAllEqual(normal.batch_shape().eval(), sf.get_shape())
|
||||
self.assertAllEqual(normal.batch_shape().eval(), sf.eval().shape)
|
||||
self.assertAllEqual(normal.get_batch_shape(), sf.get_shape())
|
||||
self.assertAllEqual(normal.get_batch_shape(), sf.eval().shape)
|
||||
|
||||
def testNormalEntropyWithScalarInputs(self):
|
||||
# Scipy.stats.norm cannot deal with the shapes in the other test.
|
||||
with self.test_session():
|
||||
|
@ -540,6 +540,16 @@ class Distribution(BaseDistribution):
|
||||
def log_cdf(self, value, name="log_cdf"):
|
||||
"""Log cumulative distribution function.
|
||||
|
||||
Given random variable `X`, the log cumulative distribution function `log_cdf` is:
|
||||
|
||||
```
|
||||
log_cdf(x) := Log[ P[X <= x] ]
|
||||
```
|
||||
|
||||
Often, a numerical approximation can be used for `log_cdf(x)` that yields
|
||||
a more accurate answer than simply taking the logarithm of the `cdf` when
|
||||
`x << -1`.
|
||||
|
||||
Args:
|
||||
value: `float` or `double` `Tensor`.
|
||||
name: The name to give this op.
|
||||
@ -556,6 +566,12 @@ class Distribution(BaseDistribution):
|
||||
def cdf(self, value, name="cdf"):
|
||||
"""Cumulative distribution function.
|
||||
|
||||
Given random variable `X`, the cumulative distribution function `cdf` is:
|
||||
|
||||
```
|
||||
cdf(x) := P[X <= x]
|
||||
```
|
||||
|
||||
Args:
|
||||
value: `float` or `double` `Tensor`.
|
||||
name: The name to give this op.
|
||||
@ -569,6 +585,57 @@ class Distribution(BaseDistribution):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
return self._cdf(value)
|
||||
|
||||
def log_survival_function(self, value, name="log_survival_function"):
|
||||
"""Log survival function.
|
||||
|
||||
Given random variable `X`, the survival function is defined:
|
||||
|
||||
```
|
||||
log_survival_function(x) = Log[ P[X > x] ]
|
||||
= Log[ 1 - P[X <= x] ]
|
||||
= Log[ 1 - cdf(x) ]
|
||||
```
|
||||
|
||||
Typically, different numerical approximations can be used for the log
|
||||
survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`.
|
||||
|
||||
Args:
|
||||
value: `float` or `double` `Tensor`.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
`Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
|
||||
`self.dtype`.
|
||||
"""
|
||||
self._check_hasattr(self._log_survival_function)
|
||||
with self._name_scope(name, values=[value]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
return self._log_survival_function(value)
|
||||
|
||||
def survival_function(self, value, name="survival_function"):
|
||||
"""Survival function.
|
||||
|
||||
Given random variable `X`, the survival function is defined:
|
||||
|
||||
```
|
||||
survival_function(x) = P[X > x]
|
||||
= 1 - P[X <= x]
|
||||
= 1 - cdf(x).
|
||||
```
|
||||
|
||||
Args:
|
||||
value: `float` or `double` `Tensor`.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
`Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
|
||||
`self.dtype`.
|
||||
"""
|
||||
self._check_hasattr(self._survival_function)
|
||||
with self._name_scope(name, values=[value]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
return self._survival_function(value)
|
||||
|
||||
def entropy(self, name="entropy"):
|
||||
"""Shannon entropy in nats."""
|
||||
self._check_hasattr(self._entropy)
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from tensorflow.contrib.bayesflow.python.ops import special_math
|
||||
from tensorflow.contrib.distributions.python.ops import distribution
|
||||
from tensorflow.contrib.distributions.python.ops import kullback_leibler
|
||||
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
|
||||
@ -169,20 +170,22 @@ class Normal(distribution.Distribution):
|
||||
|
||||
def _log_prob(self, x):
|
||||
return (-0.5 * math.log(2. * math.pi) - math_ops.log(self.sigma)
|
||||
-0.5 * math_ops.square((x - self.mu) / self.sigma))
|
||||
-0.5 * math_ops.square(self._z(x)))
|
||||
|
||||
def _prob(self, x):
|
||||
return math_ops.exp(self._log_prob(x))
|
||||
|
||||
def _log_cdf(self, x):
|
||||
return math_ops.log(self._cdf(x))
|
||||
return special_math.log_ndtr(self._z(x))
|
||||
|
||||
def _cdf(self, x):
|
||||
# TODO(ebrevdo): wrap this in a Defun with a custom Defun
|
||||
# gradient because the analytic gradient may be faster than
|
||||
# automatic differentiation.
|
||||
return (0.5 + 0.5*math_ops.erf(
|
||||
1. / (math.sqrt(2.) * self.sigma) * (x - self.mu)))
|
||||
return special_math.ndtr(self._z(x))
|
||||
|
||||
def _log_survival_function(self, x):
|
||||
return special_math.log_ndtr(-self._z(x))
|
||||
|
||||
def _survival_function(self, x):
|
||||
return special_math.ndtr(-self._z(x))
|
||||
|
||||
def _entropy(self):
|
||||
# Use broadcasting rules to calculate the full broadcast sigma.
|
||||
@ -201,6 +204,11 @@ class Normal(distribution.Distribution):
|
||||
def _mode(self):
|
||||
return self._mean()
|
||||
|
||||
def _z(self, x):
|
||||
"""Standardize input `x` to a unit normal."""
|
||||
with ops.name_scope("standardize", values=[x]):
|
||||
return (x - self.mu) / self.sigma
|
||||
|
||||
|
||||
@kullback_leibler.RegisterKL(Normal, Normal)
|
||||
def _kl_normal_normal(n_a, n_b, name=None):
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.contrib.factorization.python.ops import gmm_ops
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
||||
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||
|
||||
|
||||
@ -166,12 +167,17 @@ class GMM(estimator.Estimator, TransformerMixin):
|
||||
self.model_dir,
|
||||
gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
|
||||
|
||||
def _parse_tensor_or_dict(self, features):
|
||||
if isinstance(features, dict):
|
||||
return array_ops.concat(1, [features[k] for k in sorted(features.keys())])
|
||||
return features
|
||||
|
||||
def _get_train_ops(self, features, _):
|
||||
(_,
|
||||
_,
|
||||
losses,
|
||||
training_op) = gmm_ops.gmm(
|
||||
features,
|
||||
self._parse_tensor_or_dict(features),
|
||||
self._training_initial_clusters,
|
||||
self._num_clusters,
|
||||
self._random_seed,
|
||||
@ -187,7 +193,7 @@ class GMM(estimator.Estimator, TransformerMixin):
|
||||
model_predictions,
|
||||
_,
|
||||
_) = gmm_ops.gmm(
|
||||
features,
|
||||
self._parse_tensor_or_dict(features),
|
||||
self._training_initial_clusters,
|
||||
self._num_clusters,
|
||||
self._random_seed,
|
||||
@ -203,7 +209,7 @@ class GMM(estimator.Estimator, TransformerMixin):
|
||||
_,
|
||||
losses,
|
||||
_) = gmm_ops.gmm(
|
||||
features,
|
||||
self._parse_tensor_or_dict(features),
|
||||
self._training_initial_clusters,
|
||||
self._num_clusters,
|
||||
self._random_seed,
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
||||
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
||||
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||
|
||||
SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
|
||||
@ -222,12 +223,17 @@ class KMeansClustering(estimator.Estimator,
|
||||
"""Returns cluster centers."""
|
||||
return tf.contrib.framework.load_variable(self.model_dir, self.CLUSTERS)
|
||||
|
||||
def _parse_tensor_or_dict(self, features):
|
||||
if isinstance(features, dict):
|
||||
return array_ops.concat(1, [features[k] for k in sorted(features.keys())])
|
||||
return features
|
||||
|
||||
def _get_train_ops(self, features, _):
|
||||
(_,
|
||||
_,
|
||||
losses,
|
||||
training_op) = clustering_ops.KMeans(
|
||||
features,
|
||||
self._parse_tensor_or_dict(features),
|
||||
self._num_clusters,
|
||||
self._training_initial_clusters,
|
||||
self._distance_metric,
|
||||
@ -245,7 +251,7 @@ class KMeansClustering(estimator.Estimator,
|
||||
model_predictions,
|
||||
_,
|
||||
_) = clustering_ops.KMeans(
|
||||
features,
|
||||
self._parse_tensor_or_dict(features),
|
||||
self._num_clusters,
|
||||
self._training_initial_clusters,
|
||||
self._distance_metric,
|
||||
@ -263,7 +269,7 @@ class KMeansClustering(estimator.Estimator,
|
||||
_,
|
||||
losses,
|
||||
_) = clustering_ops.KMeans(
|
||||
features,
|
||||
self._parse_tensor_or_dict(features),
|
||||
self._num_clusters,
|
||||
self._training_initial_clusters,
|
||||
self._distance_metric,
|
||||
|
@ -21,6 +21,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/grid_rnn_test.py"],
|
||||
additional_deps = [
|
||||
":grid_rnn_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
|
@ -127,6 +127,7 @@ py_test(
|
||||
name = "optimizers_test",
|
||||
srcs = ["python/layers/optimizers_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["manual"], # http://b/31223979
|
||||
deps = [
|
||||
":layers_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
|
@ -181,7 +181,7 @@ class _TargetColumn(object):
|
||||
weight_tensor, shape=(-1,)))
|
||||
return weighted_loss
|
||||
|
||||
def training_loss(self, logits, target, features):
|
||||
def training_loss(self, logits, target, features, name="training_loss"):
|
||||
"""Returns training loss tensor for this head.
|
||||
|
||||
Training loss is different from the loss reported on the tensorboard as we
|
||||
@ -197,6 +197,7 @@ class _TargetColumn(object):
|
||||
target: either a tensor for labels or in multihead case, a dict of string
|
||||
to target tensor.
|
||||
features: features dict.
|
||||
name: Op name.
|
||||
|
||||
Returns:
|
||||
Loss tensor.
|
||||
@ -206,10 +207,9 @@ class _TargetColumn(object):
|
||||
|
||||
weight_tensor = self.get_weight_tensor(features)
|
||||
if weight_tensor is None:
|
||||
return math_ops.reduce_mean(loss_unweighted, name="loss")
|
||||
else:
|
||||
loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
|
||||
return math_ops.reduce_mean(loss_weighted, name="loss")
|
||||
return math_ops.reduce_mean(loss_unweighted, name=name)
|
||||
loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
|
||||
return math_ops.reduce_mean(loss_weighted, name=name)
|
||||
|
||||
def loss(self, logits, target, features):
|
||||
"""Returns loss tensor for this head.
|
||||
@ -233,12 +233,11 @@ class _TargetColumn(object):
|
||||
weight_tensor = self.get_weight_tensor(features)
|
||||
if weight_tensor is None:
|
||||
return math_ops.reduce_mean(loss_unweighted, name="loss")
|
||||
else:
|
||||
loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
|
||||
return math_ops.div(
|
||||
math_ops.reduce_sum(loss_weighted),
|
||||
math_ops.to_float(math_ops.reduce_sum(weight_tensor)),
|
||||
name="loss")
|
||||
loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
|
||||
return math_ops.div(
|
||||
math_ops.reduce_sum(loss_weighted),
|
||||
math_ops.to_float(math_ops.reduce_sum(weight_tensor)),
|
||||
name="loss")
|
||||
|
||||
|
||||
class _RegressionTargetColumn(_TargetColumn):
|
||||
@ -409,8 +408,10 @@ def _run_metrics(predictions, targets, metrics, weights):
|
||||
result = {}
|
||||
targets = math_ops.cast(targets, predictions.dtype)
|
||||
for name, metric in six.iteritems(metrics or {}):
|
||||
result[name] = metrics_lib.run_metric(
|
||||
metric, predictions, targets, weights=weights)
|
||||
if weights is not None:
|
||||
result[name] = metric(predictions, targets, weights=weights)
|
||||
else:
|
||||
result[name] = metric(predictions, targets)
|
||||
|
||||
return result
|
||||
|
||||
|
@ -299,6 +299,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "metric_spec_test",
|
||||
size = "small",
|
||||
srcs = ["python/learn/tests/metric_spec_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "experiment_test",
|
||||
size = "small",
|
||||
|
@ -35,12 +35,12 @@ from tensorflow.contrib.learn.python.learn.dataframe import *
|
||||
from tensorflow.contrib.learn.python.learn.estimators import *
|
||||
from tensorflow.contrib.learn.python.learn.evaluable import Evaluable
|
||||
from tensorflow.contrib.learn.python.learn.experiment import Experiment
|
||||
from tensorflow.contrib.learn.python.learn.monitors import NanLossDuringTrainingError
|
||||
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
|
||||
from tensorflow.contrib.learn.python.learn.graph_actions import infer
|
||||
from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
|
||||
from tensorflow.contrib.learn.python.learn.graph_actions import run_n
|
||||
from tensorflow.contrib.learn.python.learn.graph_actions import train
|
||||
from tensorflow.contrib.learn.python.learn.learn_io import *
|
||||
from tensorflow.contrib.learn.python.learn.monitors import NanLossDuringTrainingError
|
||||
from tensorflow.contrib.learn.python.learn.trainable import Trainable
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -79,8 +79,6 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
|
||||
Both features' `value` must be a `SparseTensor`.
|
||||
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
||||
whose `value` is a `Tensor`.
|
||||
- if `feature_columns` is `None`, then `input` must contain only real
|
||||
valued `Tensor`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -211,8 +209,6 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
|
||||
Both features' `value` must be a `SparseTensor`.
|
||||
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
||||
whose `value` is a `Tensor`.
|
||||
- if `feature_columns` is `None`, then `input` must contain only real
|
||||
valued `Tensor`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -253,9 +253,11 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
logits = array_ops.reshape(
|
||||
array_ops.tile(centered_bias[0], [batch_size]),
|
||||
[batch_size, self._target_column.num_label_columns])
|
||||
training_loss = self._target_column.training_loss(logits, targets, features)
|
||||
# Learn central bias by an optimizer. 0.1 is a convervative lr for a single
|
||||
# variable.
|
||||
with ops.name_scope(None, "centered_bias", (targets, features)):
|
||||
training_loss = self._target_column.training_loss(
|
||||
logits, targets, features)
|
||||
# Learn central bias by an optimizer. 0.1 is a convervative lr for a
|
||||
# single variable.
|
||||
return training.AdagradOptimizer(0.1).minimize(
|
||||
training_loss, var_list=centered_bias)
|
||||
|
||||
|
@ -223,10 +223,13 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
||||
dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
||||
dnn_hidden_units=[3, 3])
|
||||
|
||||
classifier.fit(input_fn=_input_fn_train, steps=100)
|
||||
scores = classifier.evaluate(input_fn=_input_fn_eval,
|
||||
steps=100)
|
||||
classifier.fit(input_fn=_input_fn_train, steps=100, monitors=(
|
||||
tf.contrib.learn.monitors.CaptureVariable(var_name='loss'),
|
||||
tf.contrib.learn.monitors.CaptureVariable(
|
||||
var_name='centered_bias/training_loss'),
|
||||
tf.contrib.learn.monitors.CaptureVariable(var_name='training_loss'),
|
||||
))
|
||||
scores = classifier.evaluate(input_fn=_input_fn_eval, steps=100)
|
||||
# If there is no weight column, model should learn y=Not(x). All examples in
|
||||
# eval data set are y=x. So if weight column is ignored, then accuracy
|
||||
# should be zero.
|
||||
@ -251,8 +254,12 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
||||
dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
|
||||
dnn_hidden_units=[3, 3])
|
||||
|
||||
classifier.fit(input_fn=_input_fn_train, steps=100)
|
||||
classifier.fit(input_fn=_input_fn_train, steps=100, monitors=(
|
||||
tf.contrib.learn.monitors.CaptureVariable(var_name='loss'),
|
||||
tf.contrib.learn.monitors.CaptureVariable(
|
||||
var_name='centered_bias/training_loss'),
|
||||
tf.contrib.learn.monitors.CaptureVariable(var_name='training_loss'),
|
||||
))
|
||||
scores = classifier.evaluate(input_fn=_input_fn_train, steps=100)
|
||||
# If weight column is ignored, then accuracy should be 0.25. If it's not
|
||||
# ignored, then it should be greater than 0.6.
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.contrib.framework import deprecated
|
||||
from tensorflow.contrib.framework import deprecated_arg_values
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import graph_actions
|
||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||
from tensorflow.contrib.learn.python.learn import session_run_hook
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
@ -52,7 +53,6 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import device_setter
|
||||
from tensorflow.python.training import saver
|
||||
@ -174,6 +174,76 @@ def _get_replica_device_setter(config):
|
||||
return None
|
||||
|
||||
|
||||
def _make_metrics_ops(metrics, features, targets, predictions):
|
||||
"""Add metrics to run on features, targets, and predictions dicts or tensors.
|
||||
|
||||
`metrics` contains a specification for how to run metrics. It is a dict
|
||||
mapping friendly names to either `MetricSpec` objects, or directly to a metric
|
||||
function (assuming that predictions and targets are single tensors), or to
|
||||
a `(pred_name, metric)` tuples, which passes `predictions[pred_name]` and
|
||||
targets to `metric` (assuming targets is a single tensor).
|
||||
|
||||
Users are encouraged to use `MetricSpec` objects, which are more flexible and
|
||||
cleaner. They also lead to clearer errors.
|
||||
|
||||
Args:
|
||||
metrics: A dict mapping names to metrics specification, for example
|
||||
`MetricSpec` objects.
|
||||
features: A dict of tensors returned from an input_fn as features/inputs.
|
||||
targets: A single tensor or a dict of tensors returned from an input_fn as
|
||||
labels.
|
||||
predictions: A single tensor or a dict of tensors output from a model as
|
||||
predictions.
|
||||
|
||||
Returns:
|
||||
A dict mapping the friendly given in `metrics` to the result of calling the
|
||||
given metric function.
|
||||
|
||||
Raises:
|
||||
ValueError: If metrics specifications do not work with the type of
|
||||
features/targets/predictions provided. Mostly, a dict is given but no
|
||||
pred_name specified.
|
||||
"""
|
||||
metrics = metrics or {}
|
||||
if isinstance(targets, dict) and len(targets) == 1:
|
||||
# Unpack single target into just tensor.
|
||||
targets = targets[list(targets.keys())[0]]
|
||||
result = {}
|
||||
for name, metric in six.iteritems(metrics):
|
||||
if isinstance(metric, metric_spec.MetricSpec):
|
||||
result[name] = metric.create_metric_ops(features, targets, predictions)
|
||||
continue
|
||||
|
||||
# TODO(b/31229024): Remove the rest of this loop
|
||||
logging.warning('Please specify metrics using MetricSpec. Using bare '
|
||||
'functions or (key, fn) tuples is deprecated and support '
|
||||
'for it will be removed on Oct 1, 2016.')
|
||||
|
||||
if isinstance(name, tuple):
|
||||
# Multi-head metrics.
|
||||
if not isinstance(predictions, dict):
|
||||
raise ValueError(
|
||||
'Metrics passed provide (name, prediction), '
|
||||
'but predictions are not dict. '
|
||||
'Metrics: %s, Predictions: %s.' % (metrics, predictions))
|
||||
# Here are two options: targets are single Tensor or a dict.
|
||||
if isinstance(targets, dict) and name[1] in targets:
|
||||
# If targets are dict and the prediction name is in it, apply metric.
|
||||
result[name[0]] = metric(predictions[name[1]], targets[name[1]])
|
||||
else:
|
||||
# Otherwise pass the targets to the metric.
|
||||
result[name[0]] = metric(predictions[name[1]], targets)
|
||||
else:
|
||||
# Single head metrics.
|
||||
if isinstance(predictions, dict):
|
||||
raise ValueError(
|
||||
'Metrics passed provide only name, no prediction, '
|
||||
'but predictions are dict. '
|
||||
'Metrics: %s, Targets: %s.' % (metrics, targets))
|
||||
result[name] = metric(predictions, targets)
|
||||
return result
|
||||
|
||||
|
||||
class BaseEstimator(
|
||||
sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
|
||||
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
|
||||
@ -389,7 +459,7 @@ class BaseEstimator(
|
||||
'The signature of the input_fn accepted by export is changing to be '
|
||||
'consistent with what\'s used by tf.Learn Estimator\'s train/evaluate. '
|
||||
'input_fn and input_feature_key will become required args, '
|
||||
'and use_deprecated_input_fn will default to False & be removed '
|
||||
'and use_deprecated_input_fn will default to False and be removed '
|
||||
'altogether.',
|
||||
use_deprecated_input_fn=True,
|
||||
input_fn=None,
|
||||
@ -470,15 +540,14 @@ class BaseEstimator(
|
||||
Args:
|
||||
features: `Tensor` or `dict` of `Tensor` objects.
|
||||
targets: `Tensor` or `dict` of `Tensor` objects.
|
||||
metrics: Dict of metric ops to run. If None, the default metric functions
|
||||
are used; if {}, no metrics are used. If model has one output (i.e.,
|
||||
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
|
||||
name of the metric that will show up in the logs / summaries.
|
||||
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
|
||||
- name of the metric and name of `Tensor` in the predictions to run
|
||||
this metric on. Metric ops should support streaming, e.g., returning
|
||||
metrics: Dict of metrics to run. If None, the default metric functions
|
||||
are used; if {}, no metrics are used. Otherwise, `metrics` should map
|
||||
friendly names for the metric to a `MetricSpec` object defining which
|
||||
model outputs to evaluate against which targets with which metric
|
||||
function. Metric ops should support streaming, e.g., returning
|
||||
update_op and value tensors. See more details in
|
||||
../../../../metrics/python/metrics/ops/streaming_metrics.py.
|
||||
`../../../../metrics/python/metrics/ops/streaming_metrics.py` and
|
||||
`../metric_spec.py`.
|
||||
|
||||
Returns:
|
||||
metrics: `dict` of `Tensor` objects.
|
||||
@ -782,8 +851,7 @@ class Estimator(BaseEstimator):
|
||||
model_fn=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
params=None,
|
||||
weight_column_name=None):
|
||||
params=None):
|
||||
"""Constructs an Estimator instance.
|
||||
|
||||
Args:
|
||||
@ -795,7 +863,7 @@ class Estimator(BaseEstimator):
|
||||
* `(features, targets, mode) -> (predictions, loss, train_op)`
|
||||
* `(features, targets, mode, params) -> (predictions, loss, train_op)`
|
||||
|
||||
Where
|
||||
Where
|
||||
|
||||
* `features` are single `Tensor` or `dict` of `Tensor`s
|
||||
(depending on data passed to `fit`),
|
||||
@ -816,9 +884,6 @@ class Estimator(BaseEstimator):
|
||||
config: Configuration object.
|
||||
params: `dict` of hyper parameters that will be passed into `model_fn`.
|
||||
Keys are names of parameters, values are basic python types.
|
||||
weight_column_name: A string defining feature column name representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
|
||||
Raises:
|
||||
ValueError: parameters of `model_fn` don't match `params`.
|
||||
@ -831,17 +896,10 @@ class Estimator(BaseEstimator):
|
||||
raise ValueError('Estimator\'s model_fn (%s) has less than 4 '
|
||||
'arguments, but not None params (%s) are passed.' %
|
||||
(model_fn, params))
|
||||
if (params is None and weight_column_name is None and
|
||||
'params' in model_fn_args):
|
||||
if params is None and 'params' in model_fn_args:
|
||||
logging.warning('Estimator\'s model_fn (%s) has includes params '
|
||||
'argument, but params are not passed to Estimator.',
|
||||
model_fn)
|
||||
self.weight_column_name = weight_column_name
|
||||
if weight_column_name is not None:
|
||||
if params is None:
|
||||
params = {'weight_column_name': weight_column_name}
|
||||
else:
|
||||
params['weight_column_name'] = weight_column_name
|
||||
self._model_fn = model_fn
|
||||
self.params = params
|
||||
|
||||
@ -855,11 +913,6 @@ class Estimator(BaseEstimator):
|
||||
return self._model_fn(features, targets, mode=mode)
|
||||
return self._model_fn(features, targets)
|
||||
|
||||
def _get_weight_tensor(self, features):
|
||||
if not self.weight_column_name:
|
||||
return None
|
||||
return math_ops.to_float(features[self.weight_column_name])
|
||||
|
||||
def _get_train_ops(self, features, targets):
|
||||
"""Method that builds model graph and returns trainer ops.
|
||||
|
||||
@ -887,15 +940,14 @@ class Estimator(BaseEstimator):
|
||||
Args:
|
||||
features: `Tensor` or `dict` of `Tensor` objects.
|
||||
targets: `Tensor` or `dict` of `Tensor` objects.
|
||||
metrics: Dict of metric ops to run. If None, the default metric functions
|
||||
are used; if {}, no metrics are used. If model has one output (i.e.,
|
||||
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
|
||||
name of the metric that will show up in the logs / summaries.
|
||||
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
|
||||
- name of the metric and name of `Tensor` in the predictions to run
|
||||
this metric on. Metric ops should support streaming, e.g., returning
|
||||
metrics: Dict of metrics to run. If None, the default metric functions
|
||||
are used; if {}, no metrics are used. Otherwise, `metrics` should map
|
||||
friendly names for the metric to a `MetricSpec` object defining which
|
||||
model outputs to evaluate against which targets with which metric
|
||||
function. Metric ops should support streaming, e.g., returning
|
||||
update_op and value tensors. See more details in
|
||||
../../../../metrics/python/metrics/ops/streaming_metrics.py.
|
||||
`../../../../metrics/python/metrics/ops/streaming_metrics.py` and
|
||||
`../metric_spec.py`.
|
||||
|
||||
Returns:
|
||||
metrics: `dict` of `Tensor` objects.
|
||||
@ -905,38 +957,7 @@ class Estimator(BaseEstimator):
|
||||
"""
|
||||
predictions, loss, _ = self._call_model_fn(features, targets, ModeKeys.EVAL)
|
||||
result = {'loss': metrics_lib.streaming_mean(loss)}
|
||||
|
||||
weights = self._get_weight_tensor(features)
|
||||
metrics = metrics or {}
|
||||
if isinstance(targets, dict) and len(targets) == 1:
|
||||
# Unpack single target into just tensor.
|
||||
targets = targets[list(targets.keys())[0]]
|
||||
for name, metric in six.iteritems(metrics):
|
||||
if isinstance(name, tuple):
|
||||
# Multi-head metrics.
|
||||
if not isinstance(predictions, dict):
|
||||
raise ValueError(
|
||||
'Metrics passed provide (name, prediction), '
|
||||
'but predictions are not dict. '
|
||||
'Metrics: %s, Predictions: %s.' % (metrics, predictions))
|
||||
# Here are two options: targets are single Tensor or a dict.
|
||||
if isinstance(targets, dict) and name[1] in targets:
|
||||
# If targets are dict and the prediction name is in it, apply metric.
|
||||
result[name[0]] = metrics_lib.run_metric(
|
||||
metric, predictions[name[1]], targets[name[1]], weights)
|
||||
else:
|
||||
# Otherwise pass the targets to the metric.
|
||||
result[name[0]] = metrics_lib.run_metric(
|
||||
metric, predictions[name[1]], targets, weights)
|
||||
else:
|
||||
# Single head metrics.
|
||||
if isinstance(predictions, dict):
|
||||
raise ValueError(
|
||||
'Metrics passed provide only name, no prediction, '
|
||||
'but predictions are dict. '
|
||||
'Metrics: %s, Targets: %s.' % (metrics, targets))
|
||||
result[name] = metrics_lib.run_metric(
|
||||
metric, predictions, targets, weights)
|
||||
result.update(_make_metrics_ops(metrics, features, targets, predictions))
|
||||
return result
|
||||
|
||||
def _get_predict_ops(self, features):
|
||||
|
@ -44,16 +44,6 @@ def boston_input_fn(num_epochs=None):
|
||||
return features, target
|
||||
|
||||
|
||||
def boston_input_with_weight_fn():
|
||||
boston = tf.contrib.learn.datasets.load_boston()
|
||||
features = {}
|
||||
features['data'] = tf.reshape(
|
||||
tf.constant(boston.data), [-1, _BOSTON_INPUT_DIM])
|
||||
target = tf.reshape(tf.constant(boston.target), [-1, 1])
|
||||
features['weight'] = tf.mul(0.5, tf.ones(target.get_shape()))
|
||||
return features, target
|
||||
|
||||
|
||||
def iris_input_fn():
|
||||
iris = tf.contrib.learn.datasets.load_iris()
|
||||
features = tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM])
|
||||
@ -92,42 +82,6 @@ def linear_model_fn(features, target, mode):
|
||||
return prediction, loss, train_op
|
||||
|
||||
|
||||
def linear_model_with_weights_fn(features, target, mode):
|
||||
assert mode in ('train', 'eval', 'infer')
|
||||
prediction, loss = (
|
||||
tf.contrib.learn.models.linear_regression_zero_init(
|
||||
features['data'], target)
|
||||
)
|
||||
train_op = tf.contrib.layers.optimize_loss(
|
||||
loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
|
||||
learning_rate=0.1)
|
||||
return prediction, loss, train_op
|
||||
|
||||
|
||||
def linear_model_with_weights_and_params_fn(features, target, mode, params):
|
||||
assert mode in ('train', 'eval', 'infer')
|
||||
prediction, loss = (
|
||||
tf.contrib.learn.models.linear_regression_zero_init(
|
||||
features['data'], target)
|
||||
)
|
||||
train_op = tf.contrib.layers.optimize_loss(
|
||||
loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
|
||||
learning_rate=params['learning_rate'])
|
||||
return prediction, loss, train_op
|
||||
|
||||
|
||||
def squared_error_weighted_sum(predictions, targets, weights=None):
|
||||
squared_error = tf.to_float(tf.square(predictions - targets))
|
||||
if weights is None:
|
||||
return tf.reduce_sum(squared_error)
|
||||
else:
|
||||
return tf.reduce_sum(tf.mul(squared_error, weights))
|
||||
|
||||
|
||||
def squared_error_no_weight(predictions, targets):
|
||||
return squared_error_weighted_sum(predictions, targets)
|
||||
|
||||
|
||||
def logistic_model_no_mode_fn(features, target):
|
||||
target = tf.one_hot(target, 3, 1, 0)
|
||||
prediction, loss = (
|
||||
@ -384,40 +338,6 @@ class EstimatorTest(tf.test.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
est.fit(input_fn=other_input_fn, steps=1)
|
||||
|
||||
def testEstimatorWithWeight(self):
|
||||
est = tf.contrib.learn.Estimator(model_fn=linear_model_with_weights_fn,
|
||||
weight_column_name='weight')
|
||||
self.assertTrue(est.params is not None)
|
||||
self.assertTrue('weight_column_name' in est.params)
|
||||
est.fit(input_fn=boston_input_with_weight_fn, steps=100)
|
||||
scores = est.evaluate(
|
||||
input_fn=boston_input_with_weight_fn, steps=100,
|
||||
metrics={'SEWS': squared_error_weighted_sum,
|
||||
'SE': squared_error_no_weight})
|
||||
self.assertNear(scores['SEWS']*2, scores['SE'], 0.01)
|
||||
|
||||
def testEstimatorWithWeightAndParams(self):
|
||||
est = tf.contrib.learn.Estimator(
|
||||
model_fn=linear_model_with_weights_and_params_fn,
|
||||
params={'learning_rate': 0.01},
|
||||
weight_column_name='weight')
|
||||
self.assertTrue('weight_column_name' in est.params)
|
||||
est.fit(input_fn=boston_input_with_weight_fn, steps=100)
|
||||
scores = est.evaluate(
|
||||
input_fn=boston_input_with_weight_fn, steps=100,
|
||||
metrics={'SEWS': squared_error_weighted_sum,
|
||||
'SE': squared_error_no_weight})
|
||||
self.assertNear(scores['SEWS']*2, scores['SE'], 0.01)
|
||||
|
||||
def testEstimatorWithNoWeight(self):
|
||||
est = tf.contrib.learn.Estimator(model_fn=linear_model_with_weights_fn)
|
||||
est.fit(input_fn=boston_input_with_weight_fn, steps=100)
|
||||
scores = est.evaluate(
|
||||
input_fn=boston_input_with_weight_fn, steps=100,
|
||||
metrics={'SEWS': squared_error_weighted_sum,
|
||||
'SE': squared_error_no_weight})
|
||||
self.assertNear(scores['SEWS'], scores['SE'], 0.01)
|
||||
|
||||
def testMonitors(self):
|
||||
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
|
||||
est.fit(input_fn=boston_input_fn,
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.contrib import metrics as metrics_lib
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
from tensorflow.contrib.layers.python.layers import target_column
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
@ -70,7 +71,7 @@ def _wrap_metric(metric):
|
||||
targets = math_ops.cast(targets, preds.dtype)
|
||||
return metric(preds, targets)
|
||||
|
||||
def wrapped_weights(preds, targets, weights):
|
||||
def wrapped_weights(preds, targets, weights=None):
|
||||
targets = math_ops.cast(targets, preds.dtype)
|
||||
if weights is not None:
|
||||
weights = array_ops.reshape(math_ops.to_float(weights), shape=(-1,))
|
||||
@ -264,6 +265,7 @@ def sdca_classifier_model_fn(features, targets, mode, params):
|
||||
loss = None
|
||||
if mode != estimator.ModeKeys.INFER:
|
||||
loss = math_ops.reduce_mean(loss_fn(logits, targets), name="loss")
|
||||
logging_ops.scalar_summary("loss", loss)
|
||||
|
||||
train_op = None
|
||||
if mode == estimator.ModeKeys.TRAIN:
|
||||
@ -347,8 +349,6 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
Both features' `value` must be a `SparseTensor`.
|
||||
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
|
||||
whose `value` is a `Tensor`.
|
||||
- if `feature_columns` is `None`, then `input` must contains only real
|
||||
valued `Tensor`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -426,8 +426,7 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
model_fn=model_fn,
|
||||
model_dir=self._model_dir,
|
||||
config=config,
|
||||
params=params,
|
||||
weight_column_name=weight_column_name)
|
||||
params=params)
|
||||
|
||||
def get_estimator(self):
|
||||
return self._estimator
|
||||
@ -445,14 +444,24 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
"""See evaluable.Evaluable."""
|
||||
if not metrics:
|
||||
metrics = {}
|
||||
metrics[("accuracy", _CLASSES)] = metrics_lib.streaming_accuracy
|
||||
metrics["accuracy"] = metric_spec.MetricSpec(
|
||||
metric_fn=metrics_lib.streaming_accuracy,
|
||||
prediction_key=_CLASSES)
|
||||
if self._n_classes == 2:
|
||||
additional_metrics = (
|
||||
target_column.get_default_binary_metrics_for_eval([0.5]))
|
||||
additional_metrics = {(name, _LOGISTIC): metric
|
||||
for name, metric in additional_metrics.items()}
|
||||
additional_metrics = {
|
||||
name: metric_spec.MetricSpec(metric_fn=metric,
|
||||
prediction_key=_LOGISTIC)
|
||||
for name, metric in additional_metrics.items()
|
||||
}
|
||||
metrics.update(additional_metrics)
|
||||
|
||||
# TODO(b/31229024): Remove this loop
|
||||
for metric_name, metric in metrics.items():
|
||||
if isinstance(metric, metric_spec.MetricSpec):
|
||||
continue
|
||||
|
||||
if isinstance(metric_name, tuple):
|
||||
if len(metric_name) != 2:
|
||||
raise ValueError("Ignoring metric %s. It returned a tuple with len "
|
||||
@ -577,8 +586,6 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
|
||||
key=weight column name, value=a `SparseTensor`}
|
||||
- if isinstance(column, `RealValuedColumn`):
|
||||
key=column.name, value=a `Tensor`
|
||||
- if `feature_columns` is `None`:
|
||||
input must contains only real valued `Tensor`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -26,6 +26,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
|
||||
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
|
||||
|
||||
|
||||
def _iris_input_fn():
|
||||
@ -137,8 +138,8 @@ class LinearClassifierTest(tf.test.TestCase):
|
||||
|
||||
def _input_fn_train():
|
||||
# Create 4 rows, one of them (y = x), three of them (y=Not(x))
|
||||
target = tf.constant([[1], [0], [0], [0]])
|
||||
features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
|
||||
target = tf.constant([[1], [0], [0], [0]], dtype=tf.float32)
|
||||
features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32)}
|
||||
return features, target
|
||||
|
||||
def _my_metric_op(predictions, targets):
|
||||
@ -155,9 +156,14 @@ class LinearClassifierTest(tf.test.TestCase):
|
||||
input_fn=_input_fn_train,
|
||||
steps=100,
|
||||
metrics={
|
||||
('my_accuracy', 'classes'): tf.contrib.metrics.streaming_accuracy,
|
||||
('my_precision', 'classes'): tf.contrib.metrics.streaming_precision,
|
||||
('my_metric', 'probabilities'): _my_metric_op
|
||||
'my_accuracy': MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_accuracy,
|
||||
prediction_key='classes'),
|
||||
'my_precision': MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_precision,
|
||||
prediction_key='classes'),
|
||||
'my_metric': MetricSpec(metric_fn=_my_metric_op,
|
||||
prediction_key='probabilities')
|
||||
})
|
||||
self.assertTrue(
|
||||
set(['loss', 'my_accuracy', 'my_precision', 'my_metric'
|
||||
|
@ -25,13 +25,12 @@ from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import metrics as metrics_lib
|
||||
from tensorflow.contrib.layers.python.layers import target_column
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import linear
|
||||
from tensorflow.contrib.learn.python.learn.utils import checkpoints
|
||||
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
def _as_iterable(preds, output):
|
||||
@ -47,21 +46,6 @@ def _get_metric_args(metric):
|
||||
if arg not in metric.keywords.keys()]
|
||||
|
||||
|
||||
def _wrap_metric(metric):
|
||||
"""Wraps metrics for mismatched prediction/target types."""
|
||||
def wrapped(preds, targets):
|
||||
targets = math_ops.cast(targets, preds.dtype)
|
||||
return metric(preds, targets)
|
||||
|
||||
def wrapped_weights(preds, targets, weights):
|
||||
targets = math_ops.cast(targets, preds.dtype)
|
||||
if weights is not None:
|
||||
weights = array_ops.reshape(math_ops.to_float(weights), shape=(-1,))
|
||||
return metric(preds, targets, weights)
|
||||
|
||||
return wrapped_weights if "weights" in _get_metric_args(metric) else wrapped
|
||||
|
||||
|
||||
class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||
"""Support Vector Machine (SVM) model for binary classification.
|
||||
|
||||
@ -100,9 +84,6 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||
whose `value` is a `SparseTensor`.
|
||||
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
|
||||
whose `value` is a `Tensor`.
|
||||
- if `feature_columns` is None, then `input` must contains only real
|
||||
valued `Tensor`.
|
||||
|
||||
|
||||
Parameters:
|
||||
example_id_column: A string defining the feature column name representing
|
||||
@ -166,15 +147,24 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||
batch_size=None, steps=None, metrics=None, name=None):
|
||||
"""See evaluable.Evaluable."""
|
||||
if not metrics:
|
||||
metrics = {
|
||||
("accuracy", linear._CLASSES): metrics_lib.streaming_accuracy,
|
||||
}
|
||||
metrics = {}
|
||||
metrics["accuracy"] = metric_spec.MetricSpec(
|
||||
metric_fn=metrics_lib.streaming_accuracy,
|
||||
prediction_key=linear._CLASSES)
|
||||
additional_metrics = (
|
||||
target_column.get_default_binary_metrics_for_eval([0.5]))
|
||||
additional_metrics = {(name, linear._LOGISTIC): metric
|
||||
for name, metric in additional_metrics.items()}
|
||||
additional_metrics = {
|
||||
name: metric_spec.MetricSpec(metric_fn=metric,
|
||||
prediction_key=linear._LOGISTIC)
|
||||
for name, metric in additional_metrics.items()
|
||||
}
|
||||
metrics.update(additional_metrics)
|
||||
|
||||
# TODO(b/31229024): Remove this loop
|
||||
for metric_name, metric in metrics.items():
|
||||
if isinstance(metric, metric_spec.MetricSpec):
|
||||
continue
|
||||
|
||||
if isinstance(metric_name, tuple):
|
||||
if len(metric_name) != 2:
|
||||
raise ValueError("Ignoring metric %s. It returned a tuple with len "
|
||||
@ -184,7 +174,7 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||
if metric_name[1] not in valid_keys:
|
||||
raise ValueError("Ignoring metric %s. The 2nd element of its name "
|
||||
"should be in %s" % (metric_name, valid_keys))
|
||||
metrics[metric_name] = _wrap_metric(metric)
|
||||
metrics[metric_name] = linear._wrap_metric(metric)
|
||||
return self._estimator.evaluate(x=x, y=y, input_fn=input_fn,
|
||||
feed_fn=feed_fn, batch_size=batch_size,
|
||||
steps=steps, metrics=metrics, name=name)
|
||||
|
186
tensorflow/contrib/learn/python/learn/metric_spec.py
Normal file
186
tensorflow/contrib/learn/python/learn/metric_spec.py
Normal file
@ -0,0 +1,186 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The metric spec class to flexibly connect models and metrics."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
class MetricSpec(object):
  """MetricSpec connects a model to metric functions.

  The MetricSpec class contains all information necessary to connect the
  output of a `model_fn` to the metrics (usually, streaming metrics) that are
  used in evaluation.

  It is passed in the `metrics` argument of `Estimator.evaluate`. The
  `Estimator` then knows which predictions, labels, and weight to use to call a
  given metric function.

  When building the ops to run in evaluation, `Estimator` will call
  `create_metric_ops`, which will connect the given `metric_fn` to the model
  as detailed in the docstring for `create_metric_ops`, and return the metric.

  Example:

  Assuming a model has an input function which returns inputs containing
  (among other things) a tensor with key "income", and a labels dictionary
  containing "has_clicked". Let's assume that the `model_fn` for this model
  returns a prediction with key "clicked".

  In order to compute the accuracy of the "clicked" prediction, we would add
  ```
  "click accuracy": MetricSpec(metric_fn=streaming_accuracy,
                               prediction_key="clicked",
                               label_key="has_clicked")
  ```
  to the metrics argument to `evaluate`. If we would like the accuracy to be
  weighted by "income", we can add that as the `weight_key` argument.
  ```
  "click accuracy": MetricSpec(metric_fn=streaming_accuracy,
                               prediction_key="clicked",
                               label_key="has_clicked",
                               weight_key="income")
  ```
  """

  def __init__(self,
               metric_fn,
               prediction_key=None,
               label_key=None,
               weight_key=None):
    """Constructor.

    Creates a MetricSpec.

    Args:
      metric_fn: A function to use as a metric. Must accept `predictions`,
        `labels` and optionally, `weights` tensors as inputs, and must return
        either a single tensor which is interpreted as a value of this metric,
        or a pair `(value_op, update_op)`, where value_op is the op to call to
        obtain the value of the metric, and update_op should be evaluated for
        each batch in order to update internal state.
      prediction_key: The key for a tensor in the `predictions` dict (output
        from the `model_fn`) to use as the `predictions` input to the
        `metric_fn`. Optional. If `None`, the `model_fn` must return a single
        tensor or a dict with only a single entry as `predictions`.
      label_key: The key for a tensor in the `labels` dict (output from the
        `input_fn`) to use as the `labels` input to the `metric_fn`.
        Optional. If `None`, the `input_fn` must return a single tensor or a
        dict with only a single entry as `labels`.
      weight_key: The key for a tensor in the `inputs` dict (output from the
        `input_fn`) to use as the `weights` input to the `metric_fn`.
        Optional. If `None`, no weights will be passed to the `metric_fn`.
    """
    self._metric_fn = metric_fn
    self._prediction_key = prediction_key
    self._label_key = label_key
    self._weight_key = weight_key

  @property
  def prediction_key(self):
    return self._prediction_key

  @property
  def label_key(self):
    return self._label_key

  @property
  def weight_key(self):
    return self._weight_key

  @property
  def metric_fn(self):
    return self._metric_fn

  def __str__(self):
    # Not every callable has `__name__` (e.g. `functools.partial` objects or
    # callable instances); fall back to `repr` so that formatting this spec,
    # which happens while logging an error below, can never raise itself.
    fn_name = getattr(self.metric_fn, '__name__', None) or repr(self.metric_fn)
    # A single tuple argument keeps formatting correct for tuple-valued keys.
    return ('MetricSpec(metric_fn=%s, prediction_key=%s, label_key=%s, '
            'weight_key=%s)' % (fn_name, self.prediction_key, self.label_key,
                                self.weight_key))

  def create_metric_ops(self, inputs, labels, predictions):
    """Connect our `metric_fn` to the specified members of the given dicts.

    This function will call the `metric_fn` given in our constructor as follows:
    ```
      metric_fn(predictions[self.prediction_key],
                labels[self.label_key],
                weights=inputs[self.weight_key])
    ```
    And returns the result. The `weights` argument is only passed if
    `self.weight_key` is not `None`.

    `predictions` and `labels` may be single tensors as well as dicts. If
    `predictions` is a single tensor, `self.prediction_key` must be `None`. If
    `predictions` is a single element dict, `self.prediction_key` is allowed to
    be `None`. Conversely, if `labels` is a single tensor, `self.label_key` must
    be `None`. If `labels` is a single element dict, `self.label_key` is allowed
    to be `None`.

    Args:
      inputs: A dict of inputs produced by the `input_fn`
      labels: A dict of labels or a single label tensor produced by the
        `input_fn`.
      predictions: A dict of predictions or a single tensor produced by the
        `model_fn`.

    Returns:
      The result of calling `metric_fn`.

    Raises:
      ValueError: If `predictions` or `labels` is a single `Tensor` and
        `self.prediction_key` or `self.label_key` is not `None`; or if
        `self.label_key` is `None` but `labels` is a dict with more than one
        element, or if `self.prediction_key` is `None` but `predictions` is a
        dict with more than one element.
    """
    def _get_dict(name, dict_or_tensor, key):
      """Get a single tensor or an element of a dict or raise ValueError."""
      if key:
        if not isinstance(dict_or_tensor, dict):
          raise ValueError('MetricSpec with ' + name + '_key specified'
                           ' requires ' +
                           name + 's dict, got %s' % (dict_or_tensor,))
        return dict_or_tensor[key]
      if isinstance(dict_or_tensor, dict):
        if len(dict_or_tensor) != 1:
          raise ValueError('MetricSpec without specified ' + name + '_key'
                           ' requires ' + name + 's tensor or single element'
                           ' dict, got %s' % (dict_or_tensor,))
        # `dict.values()` is a non-indexable view on Python 3, so take the
        # sole element through an iterator instead of `values()[0]`.
        return next(iter(dict_or_tensor.values()))
      return dict_or_tensor

    # Get the predictions
    prediction = _get_dict('prediction', predictions, self.prediction_key)

    # Get the labels
    label = _get_dict('label', labels, self.label_key)

    try:
      if self.weight_key:
        return self.metric_fn(prediction, label,
                              weights=inputs[self.weight_key])
      else:
        return self.metric_fn(prediction, label)
    except:  # pylint: disable=bare-except
      # Log which spec failed, then re-raise the original exception unchanged.
      logging.error('Could not create metric ops for %s.' % self)
      raise
|
150
tensorflow/contrib/learn/python/learn/tests/metric_spec_test.py
Normal file
150
tensorflow/contrib/learn/python/learn/tests/metric_spec_test.py
Normal file
@ -0,0 +1,150 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for MetricSpec."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
|
||||
|
||||
|
||||
def test_metric(predictions, labels, weights=None):
|
||||
return predictions, labels, weights
|
||||
|
||||
|
||||
class MetricSpecTest(tf.test.TestCase):
  """Checks how `MetricSpec.create_metric_ops` routes tensors to the metric.

  Plain strings stand in for tensors; `test_metric` echoes what it receives.
  """

  # Shared, read-only fixtures.
  _FEATURES = {"feature1": "feature1_tensor", "feature2": "feature2_tensor"}
  _LABELS = {"label1": "label1_tensor", "label2": "label2_tensor"}
  _PREDICTIONS = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}

  def test_create_metric_ops(self):
    spec = MetricSpec(metric_fn=test_metric,
                      prediction_key="pred1",
                      label_key="label1",
                      weight_key="feature2")
    prediction, label, weight = spec.create_metric_ops(
        self._FEATURES, self._LABELS, self._PREDICTIONS)

    self.assertEqual(prediction, "pred1_tensor")
    self.assertEqual(label, "label1_tensor")
    self.assertEqual(weight, "feature2_tensor")

  def test_no_weight(self):
    spec = MetricSpec(metric_fn=test_metric,
                      prediction_key="pred1",
                      label_key="label1")
    prediction, label, weight = spec.create_metric_ops(
        self._FEATURES, self._LABELS, self._PREDICTIONS)

    self.assertEqual(prediction, "pred1_tensor")
    self.assertEqual(label, "label1_tensor")
    self.assertEqual(weight, None)

  def test_fail_no_prediction(self):
    spec = MetricSpec(metric_fn=test_metric,
                      label_key="label1",
                      weight_key="feature2")
    with self.assertRaisesRegexp(ValueError,
                                 "MetricSpec without specified prediction_key "
                                 "requires predictions tensor or single element "
                                 "dict, got"):
      spec.create_metric_ops(self._FEATURES, self._LABELS, self._PREDICTIONS)

  def test_fail_no_label(self):
    spec = MetricSpec(metric_fn=test_metric,
                      prediction_key="pred1",
                      weight_key="feature2")
    with self.assertRaisesRegexp(ValueError,
                                 "MetricSpec without specified label_key requires "
                                 "labels tensor or single element dict, got"):
      spec.create_metric_ops(self._FEATURES, self._LABELS, self._PREDICTIONS)

  def test_single_prediction(self):
    spec = MetricSpec(metric_fn=test_metric,
                      label_key="label1",
                      weight_key="feature2")
    prediction, label, weight = spec.create_metric_ops(
        self._FEATURES, self._LABELS, "pred1_tensor")

    self.assertEqual(prediction, "pred1_tensor")
    self.assertEqual(label, "label1_tensor")
    self.assertEqual(weight, "feature2_tensor")

  def test_single_label(self):
    spec = MetricSpec(metric_fn=test_metric,
                      prediction_key="pred1",
                      weight_key="feature2")
    prediction, label, weight = spec.create_metric_ops(
        self._FEATURES, "label1_tensor", self._PREDICTIONS)

    self.assertEqual(prediction, "pred1_tensor")
    self.assertEqual(label, "label1_tensor")
    self.assertEqual(weight, "feature2_tensor")

  def test_fail_single_prediction(self):
    spec = MetricSpec(metric_fn=test_metric,
                      prediction_key="pred1",
                      label_key="label1",
                      weight_key="feature2")
    with self.assertRaisesRegexp(ValueError,
                                 "MetricSpec with prediction_key specified requires "
                                 "predictions dict, got"):
      spec.create_metric_ops(self._FEATURES, self._LABELS, "pred1_tensor")

  def test_fail_single_label(self):
    spec = MetricSpec(metric_fn=test_metric,
                      prediction_key="pred1",
                      label_key="label1",
                      weight_key="feature2")
    with self.assertRaisesRegexp(ValueError,
                                 "MetricSpec with label_key specified requires "
                                 "labels dict, got"):
      spec.create_metric_ops(self._FEATURES, "label1_tensor", self._PREDICTIONS)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -367,6 +367,7 @@ class SdcaModel(object):
|
||||
|
||||
logging_ops.scalar_summary('approximate_duality_gap',
|
||||
self.approximate_duality_gap())
|
||||
logging_ops.scalar_summary('examples_seen', self._hashtable.size())
|
||||
|
||||
def _symmetric_l1_regularization(self):
|
||||
return self._options['symmetric_l1_regularization']
|
||||
|
@ -116,6 +116,7 @@ weighted average over the individual prediction errors:
|
||||
@@mean_squared_error
|
||||
@@sigmoid_cross_entropy
|
||||
@@softmax_cross_entropy
|
||||
@@sparse_softmax_cross_entropy
|
||||
|
||||
The following are deprecated in favor of `mean_pairwise_squared_error` and
|
||||
`mean_squared_error`.
|
||||
|
@ -41,6 +41,7 @@ __all__ = ["absolute_difference",
|
||||
"mean_squared_error",
|
||||
"sigmoid_cross_entropy",
|
||||
"softmax_cross_entropy",
|
||||
"sparse_softmax_cross_entropy",
|
||||
"sum_of_pairwise_squares",
|
||||
"sum_of_squares"]
|
||||
|
||||
@ -354,8 +355,8 @@ def softmax_cross_entropy(logits, onehot_labels, weight=1.0,
|
||||
A scalar `Tensor` representing the loss value.
|
||||
|
||||
Raises:
|
||||
ValueError: If the shape of `predictions` doesn't match that of `targets` or
|
||||
if the shape of `weight` is invalid or if `weight` is None.
|
||||
ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
|
||||
or if the shape of `weight` is invalid or if `weight` is None.
|
||||
"""
|
||||
with ops.name_scope(scope, "softmax_cross_entropy_loss",
|
||||
[logits, onehot_labels]):
|
||||
@ -375,6 +376,39 @@ def softmax_cross_entropy(logits, onehot_labels, weight=1.0,
|
||||
return _compute_weighted_loss(losses, weight)
|
||||
|
||||
|
||||
def sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None):
  """Weighted cross-entropy via tf.nn.sparse_softmax_cross_entropy_with_logits.

  `weight` is a coefficient for the loss. A scalar simply scales the loss; a
  tensor of size [`batch_size`] weights each corresponding sample.

  Args:
    logits: [batch_size, num_classes] logits outputs of the network.
    labels: [batch_size, 1] or [batch_size] target labels of dtype `int32` or
      `int64` in the range `[0, num_classes)`.
    weight: Coefficients for the loss. The tensor must be a scalar or a tensor
      of shape [batch_size] or [batch_size, 1].
    scope: the scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If the shapes of logits, labels, and weight are incompatible, or
      if `weight` is None.
  """
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      [logits, labels]):
    # Flatten the labels to rank 1, sized by their leading dimension.
    batch_size = array_ops.shape(labels)[0]
    flat_labels = array_ops.reshape(labels, shape=[batch_size])
    # Remove size-1 dimensions from the weight.
    squeezed_weight = array_ops.squeeze(weight)

    per_example_losses = nn.sparse_softmax_cross_entropy_with_logits(
        logits, flat_labels, name="xentropy")
    return _compute_weighted_loss(per_example_losses, squeezed_weight)
|
||||
|
||||
|
||||
def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
|
||||
"""Adds a Log Loss term to the training procedure.
|
||||
|
||||
|
@ -173,7 +173,7 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
|
||||
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
|
||||
self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)
|
||||
|
||||
def testAllWrongAllMissing(self):
|
||||
def testAllWrongAllWeightsMissing(self):
|
||||
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||
[0.0, 10.0, 0.0],
|
||||
[0.0, 0.0, 10.0]])
|
||||
@ -185,7 +185,7 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
|
||||
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
|
||||
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||
|
||||
def testSomeMissing(self):
|
||||
def testSomeWeightsMissing(self):
|
||||
logits = tf.constant([[10.0, 0.0, 0.0],
|
||||
[0.0, 10.0, 0.0],
|
||||
[0.0, 0.0, 10.0]])
|
||||
@ -235,6 +235,216 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
|
||||
self.assertAlmostEqual(loss.eval(), expected_value, 3)
|
||||
|
||||
|
||||
class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase):
  """Tests for `tf.contrib.losses.sparse_softmax_cross_entropy`."""

  _LOSS_OP_NAME = 'sparse_softmax_cross_entropy_loss/value'

  # ----------------------------------------------------------------- helpers

  def _diagonal_logits(self, size=3, on=10.0, off=0.0):
    """Returns a [size, size] constant with `on` on the diagonal, else `off`."""
    rows = [[on if col == row else off for col in range(size)]
            for row in range(size)]
    return tf.constant(rows)

  def _loss(self, *args, **kwargs):
    """Forwards to the loss function under test."""
    return tf.contrib.losses.sparse_softmax_cross_entropy(*args, **kwargs)

  def _check_named_loss(self, loss, expected):
    """Asserts the loss op's name and its evaluated value (3 places)."""
    self.assertEquals(loss.op.name, self._LOSS_OP_NAME)
    self.assertAlmostEqual(loss.eval(), expected, 3)

  # ------------------------------------------------------------------- tests

  def testNoneWeightRaisesValueError(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[0], [1], [2]])
    with self.test_session():
      with self.assertRaises(ValueError):
        self._loss(logits, labels, weight=None)

  def testAllCorrectInt32Labels(self):
    with self.test_session():
      logits = self._diagonal_logits()
      labels = tf.constant([[0], [1], [2]], dtype=tf.int32)
      self._check_named_loss(self._loss(logits, labels), 0.0)

  def testAllCorrectInt64Labels(self):
    with self.test_session():
      logits = self._diagonal_logits()
      labels = tf.constant([[0], [1], [2]], dtype=tf.int64)
      self._check_named_loss(self._loss(logits, labels), 0.0)

  def testAllCorrectNonColumnLabels(self):
    with self.test_session():
      logits = self._diagonal_logits()
      labels = tf.constant([0, 1, 2])
      self._check_named_loss(self._loss(logits, labels), 0.0)

  def testAllWrongInt32Labels(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[2], [0], [1]], dtype=tf.int32)

    with self.test_session():
      self._check_named_loss(self._loss(logits, labels), 10.0)

  def testAllWrongInt64Labels(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[2], [0], [1]], dtype=tf.int64)

    with self.test_session():
      self._check_named_loss(self._loss(logits, labels), 10.0)

  def testAllWrongNonColumnLabels(self):
    logits = self._diagonal_logits()
    labels = tf.constant([2, 0, 1])

    with self.test_session():
      self._check_named_loss(self._loss(logits, labels), 10.0)

  def testNonZeroLossWithPythonScalarWeight(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[2], [0], [1]])
    weight = 2.3
    with self.test_session():
      loss = self._loss(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), weight * 10.0, 3)

  def testNonZeroLossWithScalarTensorWeight(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[2], [0], [1]])
    weight = 2.3
    with self.test_session():
      loss = self._loss(logits, labels, tf.constant(weight))
      self.assertAlmostEqual(loss.eval(), weight * 10.0, 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[2], [0], [1]])
    weight = tf.constant([1.2, 3.4, 5.6], shape=[3])
    with self.test_session():
      loss = self._loss(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)

  def testNonZeroLossWithColumnWeights(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[2], [0], [1]])
    weight = tf.constant([[1.2], [3.4], [5.6]])
    with self.test_session():
      loss = self._loss(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)

  def testAllWrongAllWeightsMissing(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[2], [0], [1]])
    weight = tf.constant([0, 0, 0], shape=[3])
    with self.test_session():
      loss = self._loss(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), 0.0, 3)

  def testSomeWeightsMissing(self):
    logits = self._diagonal_logits()
    labels = tf.constant([[2], [0], [1]])
    weight = tf.constant([1.2, 0, 0], shape=[3])
    with self.test_session():
      loss = self._loss(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), 12.0, 3)

  def testMeasurementSpecificWeightsRaisesException(self):
    with self.test_session():
      logits = self._diagonal_logits(on=100.0, off=-100.0)
      labels = tf.constant([[0], [1], [2]])
      weight = tf.constant([[3, 4, 5],
                            [2, 6, 0],
                            [8, 0, 1]])

      with self.assertRaises(ValueError):
        self._loss(logits, labels, weight=weight).eval()

  def testInconsistentWeightSizeRaisesException(self):
    """The weight tensor has incorrect number of elements."""
    with self.test_session():
      logits = self._diagonal_logits(on=100.0, off=-100.0)
      labels = tf.constant([[0], [1], [2]])
      weight = tf.constant([1.2, 3.4, 5.6, 7.8])

      with self.assertRaises(ValueError):
        self._loss(logits, labels, weight=weight).eval()

  def testInconsistentLabelSizeRaisesException(self):
    """The label tensor has incorrect number of elements."""
    with self.test_session():
      logits = self._diagonal_logits(on=100.0, off=-100.0)
      labels = tf.constant([[0], [1], [2], [3]])
      weight = tf.constant([1.2, 3.4, 5.6])

      with self.assertRaises(ValueError):
        self._loss(logits, labels, weight=weight).eval()

  def testInconsistentWeightShapeRaisesException(self):
    """The weight tensor has incorrect shape."""
    with self.test_session():
      logits = self._diagonal_logits(size=4, on=100.0, off=-100.0)
      labels = tf.constant([[0], [1], [2], [3]])
      weight = tf.constant([[1.2, 3.4], [5.6, 7.8]])

      with self.assertRaises(ValueError):
        self._loss(logits, labels, weight=weight).eval()

  def testInconsistentLabelShapeRaisesException(self):
    """The label tensor has incorrect shape."""
    with self.test_session():
      logits = self._diagonal_logits(size=4, on=100.0, off=-100.0)
      labels = tf.constant([[0, 1], [2, 3]])
      weight = tf.constant([1.2, 3.4, 5.6, 7.8])

      with self.assertRaises(tf.errors.InvalidArgumentError):
        self._loss(logits, labels, weight=weight).eval()
|
||||
|
||||
|
||||
class SigmoidCrossEntropyLossTest(tf.test.TestCase):
|
||||
|
||||
def testAllCorrectSigmoid(self):
|
||||
|
@ -419,6 +419,7 @@ $(wildcard tensorflow/core/graph/*.cc) \
|
||||
$(wildcard tensorflow/core/lib/*/*.cc) \
|
||||
$(wildcard tensorflow/core/platform/*.cc) \
|
||||
$(wildcard tensorflow/core/platform/*/*.cc) \
|
||||
$(wildcard tensorflow/core/platform/*/*/*.cc) \
|
||||
$(wildcard tensorflow/core/util/*.cc) \
|
||||
$(wildcard tensorflow/core/util/*/*.cc)
|
||||
CORE_CC_EXCLUDE_SRCS := \
|
||||
|
@ -25,7 +25,7 @@ tensorflow/core/lib/random/simple_philox.cc
|
||||
tensorflow/core/lib/random/random.cc
|
||||
tensorflow/core/lib/random/distribution_sampler.cc
|
||||
tensorflow/core/lib/io/zlib_outputbuffer.cc
|
||||
tensorflow/core/lib/io/zlib_inputbuffer.cc
|
||||
tensorflow/core/lib/io/zlib_inputstream.cc
|
||||
tensorflow/core/lib/io/two_level_iterator.cc
|
||||
tensorflow/core/lib/io/table_builder.cc
|
||||
tensorflow/core/lib/io/table.cc
|
||||
|
@ -76,6 +76,7 @@ tensorflow/core/kernels/example_parsing_ops.cc
|
||||
tensorflow/core/kernels/dynamic_stitch_op.cc
|
||||
tensorflow/core/kernels/dynamic_partition_op.cc
|
||||
tensorflow/core/kernels/dense_update_ops.cc
|
||||
tensorflow/core/kernels/deep_conv2d.cc
|
||||
tensorflow/core/kernels/cwise_ops_common.cc
|
||||
tensorflow/core/kernels/cwise_op_tanh.cc
|
||||
tensorflow/core/kernels/cwise_op_sub.cc
|
||||
@ -100,6 +101,7 @@ tensorflow/core/kernels/cwise_op_div.cc
|
||||
tensorflow/core/kernels/cwise_op_add.cc
|
||||
tensorflow/core/kernels/ctc_decoder_ops.cc
|
||||
tensorflow/core/kernels/conv_ops_using_gemm.cc
|
||||
tensorflow/core/kernels/conv_ops_fused.cc
|
||||
tensorflow/core/kernels/conv_ops.cc
|
||||
tensorflow/core/kernels/conv_grad_ops.cc
|
||||
tensorflow/core/kernels/control_flow_ops.cc
|
||||
|
@ -127,7 +127,6 @@ time.
|
||||
|
||||
@@aggregate_metrics
|
||||
@@aggregate_metric_map
|
||||
@@run_metric
|
||||
|
||||
## Set `Ops`
|
||||
|
||||
@ -147,7 +146,6 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion
|
||||
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import run_metric
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean
|
||||
|
@ -213,7 +213,7 @@ void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
|
||||
result->clear();
|
||||
auto input_flat = input_tensor.flat<T>();
|
||||
const auto start = std::inner_product(
|
||||
group_indices.begin(), group_indices.end(), input_strides.begin(), 0);
|
||||
group_indices.begin(), group_indices.end(), input_strides.begin(), 0L);
|
||||
const TensorShape& input_shape = input_tensor.shape();
|
||||
const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
|
||||
for (int64 i = start; i < end; ++i) {
|
||||
@ -273,7 +273,7 @@ void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
|
||||
|
||||
const auto group_key = group.group();
|
||||
const auto output_index = std::inner_product(
|
||||
group_key.begin(), group_key.end(), output_strides.begin(), 0);
|
||||
group_key.begin(), group_key.end(), output_strides.begin(), 0L);
|
||||
out(output_index) = group_set.size();
|
||||
}
|
||||
}
|
||||
@ -441,7 +441,7 @@ void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
|
||||
|
||||
std::set<T> group_set;
|
||||
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
||||
if (group_set.size() > 0) {
|
||||
if (!group_set.empty()) {
|
||||
group_sets[group_indices] = group_set;
|
||||
const auto set_size = group_set.size();
|
||||
if (set_size > max_set_size) {
|
||||
@ -516,7 +516,7 @@ void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
|
||||
|
||||
std::set<T> group_set;
|
||||
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
||||
if (group_set.size() > 0) {
|
||||
if (!group_set.empty()) {
|
||||
group_sets[group_indices] = group_set;
|
||||
const auto set_size = group_set.size();
|
||||
if (set_size > max_set_size) {
|
||||
@ -632,7 +632,7 @@ void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
|
||||
|
||||
std::set<T> group_set;
|
||||
ApplySetOperation(set1_group_set, set2_group_set, &group_set);
|
||||
if (group_set.size() > 0) {
|
||||
if (!group_set.empty()) {
|
||||
group_sets[*group_indices] = group_set;
|
||||
const auto set_size = group_set.size();
|
||||
if (set_size > max_set_size) {
|
||||
|
@ -121,7 +121,7 @@ class ConfusionMatrixTest(tf.test.TestCase):
|
||||
predictions = np.asarray([1, 2, 3])
|
||||
labels = np.asarray([1, 2])
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "are not compatible",
|
||||
ValueError, "must be equal",
|
||||
tf.contrib.metrics.confusion_matrix, predictions, labels)
|
||||
|
||||
def testOutputIsInt32(self):
|
||||
|
@ -22,8 +22,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import inspect
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
|
||||
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
|
||||
@ -467,6 +465,8 @@ def streaming_accuracy(predictions, labels, weights=None,
|
||||
predictions, labels = metric_ops_util.remove_squeezable_dimensions(
|
||||
predictions, labels)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
if labels.dtype != predictions.dtype:
|
||||
predictions = math_ops.cast(predictions, labels.dtype)
|
||||
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
|
||||
return streaming_mean(is_correct, weights, metrics_collections,
|
||||
updates_collections, name or 'accuracy')
|
||||
@ -2126,37 +2126,4 @@ def aggregate_metric_map(names_to_tuples):
|
||||
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
|
||||
|
||||
|
||||
def run_metric(metric, predictions, targets, weights=None):
  """Runs a single metric.

  This function runs `metric` on the given predictions and targets. `weights`
  is forwarded to the metric only if the metric accepts a `weights` argument
  that has not already been bound.

  Args:
    metric: A function that evaluates targets given predictions. Either a
      regular function or a `functools.partial` wrapping one.
    predictions: A `Tensor` of arbitrary shape.
    targets: A `Tensor` of the same shape as `predictions`.
    weights: A set of weights that can be used in metric function to compute
      weighted result.

  Returns:
    result: result returned by metric function.
  """
  # Prefer `getfullargspec` where it exists: `getargspec` is deprecated on
  # Python 3 and absent from recent releases. `or` short-circuits, so the
  # legacy name is only looked up when the modern one is missing.
  getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec

  metric_args = []
  if hasattr(metric, '__code__'):
    # Regular function.
    metric_args = getargspec(metric).args
  elif hasattr(metric, 'func') and hasattr(metric, 'keywords'):
    # Partial function: keep every argument of the wrapped function that has
    # not been bound by keyword. A partial with no bound keywords has empty
    # (or None) `keywords`; treat that as "nothing bound" rather than
    # discarding all arguments.
    bound_keywords = metric.keywords or {}
    metric_args = [arg for arg in getargspec(metric.func).args
                   if arg not in bound_keywords]

  if 'weights' in metric_args:
    return metric(predictions, targets, weights=weights)
  return metric(predictions, targets)
|
||||
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
@ -2851,37 +2850,5 @@ class AggregateMetricMapTest(tf.test.TestCase):
|
||||
self.assertEqual(4, names_to_values['m2'].eval())
|
||||
|
||||
|
||||
class RunMetricTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
tf.reset_default_graph()
|
||||
|
||||
def testRunMetric(self):
|
||||
predictions = tf.constant([2, 4, 6, 8], shape=(1, 4), dtype=tf.float32)
|
||||
labels = tf.constant([1, 3, 2, 3], shape=(1, 4), dtype=tf.float32)
|
||||
weights = tf.constant([0, 1, 0, 1], shape=(1, 4))
|
||||
|
||||
error, update_op = metrics.run_metric(metrics.streaming_mean_squared_error,
|
||||
predictions, labels, weights)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_local_variables())
|
||||
self.assertEqual(13, sess.run(update_op))
|
||||
self.assertEqual(13, error.eval())
|
||||
|
||||
def testRunMetricsWithOutWeights(self):
|
||||
predictions = tf.constant([2, 4, 6], shape=(1, 3), dtype=tf.float32)
|
||||
labels = tf.constant([1, 3, 2], shape=(1, 3), dtype=tf.float32)
|
||||
|
||||
streaming_mean_squared_error_no_weight = partial(
|
||||
metrics.streaming_mean_squared_error, weights=None)
|
||||
|
||||
error, update_op = metrics.run_metric(
|
||||
streaming_mean_squared_error_no_weight, predictions, labels)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_local_variables())
|
||||
self.assertEqual(6, sess.run(update_op))
|
||||
self.assertEqual(6, error.eval())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
#include "tensorflow/core/platform/hexagon/gemm_wrapper.h"
|
||||
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
@ -49,6 +50,30 @@ class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
|
||||
}
|
||||
};
|
||||
|
||||
// Shows some statistics of hexagon dsp using hexagon specific APIs
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
TEST_F(QuantizedMatMulOpForHexagonTest, EvaluateSharedLibOverhead) {
  using profile_utils::CpuUtils;

  // Bracket a call into the wrapper shared library with clock-cycle reads.
  const uint64 lib_begin = CpuUtils::GetCurrentClockCycle();
  const int lib_version = hexagon_gemm_wrapper_GetWrapperVersion();
  const uint64 lib_finish = CpuUtils::GetCurrentClockCycle();

  // Bracket the call that queries the hexagon binary version the same way.
  const uint64 rpc_begin = CpuUtils::GetCurrentClockCycle();
  const int binary_version = hexagon_gemm_wrapper_GetHexagonBinaryVersion();
  const uint64 rpc_finish = CpuUtils::GetCurrentClockCycle();

  const uint64 lib_cycles = lib_finish - lib_begin;
  const uint64 rpc_cycles = rpc_finish - rpc_begin;

  LOG(INFO) << "Shared lib (ver = " << lib_version << ") overhead is "
            << lib_cycles << " cycles";
  LOG(INFO) << "hexagon rpc (ver = " << binary_version << ") overhead is "
            << rpc_cycles << " cycles";
}
|
||||
#endif
|
||||
|
||||
// Runs two small matrices through the operator, and leaves all the parameters
|
||||
// at their default values.
|
||||
// This test is a sample to execute matmul on hexagon.
|
||||
|
@ -28,6 +28,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/rnn_cell_test.py"],
|
||||
additional_deps = [
|
||||
":rnn_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -39,6 +40,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/lstm_ops_test.py"],
|
||||
additional_deps = [
|
||||
":rnn_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -82,6 +84,7 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/gru_ops_test.py"],
|
||||
additional_deps = [
|
||||
":rnn_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
|
@ -11,6 +11,7 @@ py_library(
|
||||
name = "training_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/training/bucket_ops.py",
|
||||
"python/training/sampling_ops.py",
|
||||
"python/training/sequence_queueing_state_saver.py",
|
||||
],
|
||||
@ -67,6 +68,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "bucket_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["python/training/bucket_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -38,6 +38,17 @@ balanced.
|
||||
|
||||
@@stratified_sample
|
||||
@@stratified_sample_unknown_dist
|
||||
|
||||
## Bucketing
|
||||
|
||||
Use [`bucket`](#bucket) or
|
||||
[`bucket_by_sequence_length`](#bucket_by_sequence_length) to stratify
|
||||
minibatches into groups ("buckets"). Use `bucket_by_sequence_length`
|
||||
with the argument `dynamic_pad=True` to receive minibatches of similarly
|
||||
sized sequences for efficient training via `dynamic_rnn`.
|
||||
|
||||
@@bucket
|
||||
@@bucket_by_sequence_length
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -45,6 +56,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.training.python.training.bucket_ops import *
|
||||
from tensorflow.contrib.training.python.training.sampling_ops import *
|
||||
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
374
tensorflow/contrib/training/python/training/bucket_ops.py
Normal file
374
tensorflow/contrib/training/python/training/bucket_ops.py
Normal file
@ -0,0 +1,374 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Operations for bucketing data into groups.
|
||||
|
||||
The classes and functions in this module are used to queue up data into
|
||||
buckets conditional on side information (e.g. sequence length).
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.training import input as input_py
|
||||
from tensorflow.python.training import queue_runner
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_as_original_type = input_py._as_original_type
|
||||
_as_tensor_list = input_py._as_tensor_list
|
||||
_deserialize_sparse_tensors = input_py._deserialize_sparse_tensors
|
||||
_dtypes = input_py._dtypes
|
||||
_serialize_sparse_tensors = input_py._serialize_sparse_tensors
|
||||
_shapes = input_py._shapes
|
||||
_which_queue = input_py._which_queue
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def _validate_bucket(tensor_list):
  """Converts the inputs of `bucket()` and rejects an empty list.

  Args:
    tensor_list: The values to convert with
      `ops.convert_n_to_tensor_or_indexed_slices`.

  Returns:
    The converted list.

  Raises:
    ValueError: If the converted list is empty.
  """
  converted = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
  if converted:
    return converted
  raise ValueError("Expected at least one tensor in bucket().")
|
||||
|
||||
|
||||
def bucket(tensors,
           which_bucket,
           batch_size,
           num_buckets,
           num_threads=1,
           capacity=32,
           shapes=None,
           dynamic_pad=False,
           allow_smaller_final_batch=False,
           keep_input=None,
           shared_name=None,
           name=None):
  """Lazy bucketing of input tensors according to `which_bucket`.

  The argument `tensors` can be a list or a dictionary of tensors.
  The value returned by the function will be of the same type
  as `tensors`.

  The tensors entering this function are put into the bucket given by
  `which_bucket`.  Each bucket has its own queue.  When a bucket contains
  `batch_size` elements, this minibatch is pushed onto a top queue.  The
  tensors returned from this function are the result of dequeueing the
  next minibatch from this top queue.

  This function is implemented using several queues.  A `QueueRunner` for the
  queues is added to the current `Graph`'s `QUEUE_RUNNER` collection.

  As the returned tensors are the result of a dequeue operation, evaluating
  them will throw a `tf.errors.OutOfRangeError` when the input queue is
  exhausted.  If these tensors are feeding another input queue, its queue
  runner will catch this exception, however, if they are used in your main
  thread you are responsible for catching this yourself.

  *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
  (i) the `shapes` argument is passed, or (ii) all of the tensors in
  `tensors` must have fully-defined shapes. `ValueError` will be
  raised if neither of these conditions holds.

  If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
  tensors is known, but individual dimensions may have shape `None`.
  In this case, for each enqueue the dimensions with value `None`
  may have a variable length; upon dequeue, the output tensors will be padded
  on the right to the maximum shape of the tensors in the current minibatch.
  For numbers, this padding takes value 0.  For strings, this padding is
  the empty string.  See `PaddingFIFOQueue` for more info.

  If `allow_smaller_final_batch` is `True`, a smaller batch value than
  `batch_size` is returned when the queues are closed and there are not enough
  elements to fill the batch, otherwise the pending elements are discarded.
  In addition, all output tensors' static shapes, as accessed via the
  `get_shape()` method will have a 0th `Dimension` value of `None`, and
  operations that depend on fixed batch_size would fail.

  Args:
    tensors: The list or dictionary of tensors, representing a single element,
      to bucket.  Nested lists are not supported.
    which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
    batch_size: The new batch size pulled from the queue
      (python int or int32 scalar).
    num_buckets: A python integer, the number of buckets.
    num_threads: An integer.  The number of threads enqueuing `tensors`.
    capacity: An integer. The maximum number of minibatches in the top queue,
      and also the maximum number of elements within each bucket.
    shapes: (Optional) The shapes for each example.  Defaults to the
      inferred shapes for `tensors`.
    dynamic_pad: Boolean.  Allow variable dimensions in input shapes.
      The given dimensions are padded upon dequeue so that tensors within a
      batch have the same shapes.
    allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
      batches to be smaller if there are insufficient items left in the queues.
    keep_input: (Optional).  A `bool` scalar Tensor.  If provided, this tensor
      controls whether the input is added to the queue or not.  If it evaluates
      `True`, then `tensors` are added to the bucket; otherwise they are
      dropped.  This tensor essentially acts as a filtering mechanism.
      The default behavior is to assume `keep_input=True`.
    shared_name: (Optional). If set, the queues will be shared under the given
      name across multiple sessions.
    name: (Optional) A name for the operations.

  Returns:
    A tuple `(bucket, outputs)` where `bucket` is
    a `int32` scalar tensor and `outputs` is a list or
    dictionary of batched outputs corresponding to elements of `tensors`.
    Every step will receive a new bucket of outputs.

  Raises:
    ValueError: If the `shapes` are not specified, and cannot be
      inferred from the elements of `tensors`.
  """
  tensor_list = _as_tensor_list(tensors)
  with ops.name_scope(name, "bucket", tensor_list) as name:
    tensor_list = _validate_bucket(tensor_list)
    (tensor_list, sparse_info) = _serialize_sparse_tensors(
        tensor_list, enqueue_many=False)

    # Round-trip batch_size to a tensor, and possibly back
    batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int32, name="batch_size")
    static_batch_size = tensor_util.constant_value(batch_size)
    batch_size = (
        static_batch_size if static_batch_size is not None else batch_size)

    types = _dtypes([tensor_list])
    shapes = _shapes([tensor_list], shapes, enqueue_many=False)

    which_bucket = ops.convert_to_tensor(
        which_bucket, dtype=dtypes.int32, name="which_bucket")

    # One queue per bucket; each holds individual (unbatched) elements.
    queue_creator = _which_queue(dynamic_pad)
    bucket_queues = []
    for i in range(num_buckets):
      shared_name_i = (
          "%s_%d" % (shared_name, i) if shared_name is not None else None)
      bucket_queues.append(
          queue_creator(capacity=capacity,
                        dtypes=types,
                        shapes=shapes,
                        shared_name=shared_name_i, name="bucket_queue_%d" % i))

    maybe_static_batch_size = (
        None if allow_smaller_final_batch else static_batch_size)

    bucket_shapes = [tensor_shape.vector(maybe_static_batch_size).concatenate(s)
                     for s in bucket_queues[0].shapes]
    # top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO
    # queues because if we use allow_smaller_final_batch, shapes will
    # contain Nones in their first entry; as a result, a regular
    # FIFOQueue would die when being passed shapes that are not fully defined.
    top_queue = data_flow_ops.PaddingFIFOQueue(
        capacity=capacity,
        dtypes=[dtypes.int32] + types,
        shapes=[tensor_shape.scalar()] + bucket_shapes,
        shared_name=shared_name, name="top_queue")

    def enqueue_which():
      # Build one conditional enqueue per bucket; only the branch whose index
      # equals `which_bucket` enqueues, the others are no-ops.
      def enqueue_single(i):
        return bucket_queues[i].enqueue(tensor_list)
      enqueues = [
          control_flow_ops.cond(
              math_ops.equal(which_bucket, i),
              functools.partial(enqueue_single, i),
              control_flow_ops.no_op)
          for i in range(num_buckets)]
      return control_flow_ops.group(*enqueues, name="group_enqueues")

    if keep_input is not None:
      # TODO(ebrevdo): Expand keep_input param to core training
      # methods, and pipe through to _serialize_sparse_tensors; so
      # that expensive serialization is guarded by keep_input.
      maybe_enqueue = control_flow_ops.cond(
          keep_input,
          enqueue_which,
          control_flow_ops.no_op)
    else:
      maybe_enqueue = enqueue_which()

    bucket_enqueue_ops = [maybe_enqueue] * num_threads

    if allow_smaller_final_batch:
      which_dequeue = lambda q: q.dequeue_up_to
    else:
      which_dequeue = lambda q: q.dequeue_many

    # For every bucket: dequeue a minibatch and push it, tagged with the
    # bucket index, onto the top queue.
    enqueues_to_top = [
        top_queue.enqueue(
            [constant_op.constant(i)] +
            which_dequeue(q)(batch_size, name="read_bucket_%d" % i),
            name="enqueue_from_bucket_%d" % i)
        for i, q in enumerate(bucket_queues)]

    for i, q in enumerate(bucket_queues):
      queue_runner.add_queue_runner(queue_runner.QueueRunner(
          q, [enqueues_to_top[i]],
          queue_closed_exception_types=(
              errors.OutOfRangeError, errors.CancelledError)))
    queue_runner.add_queue_runner(queue_runner.QueueRunner(
        top_queue, bucket_enqueue_ops,
        queue_closed_exception_types=(
            errors.OutOfRangeError, errors.CancelledError)))

    for q in bucket_queues:
      # Report the fill level of this bucket's own queue. The summary is named
      # after `q`, so it must read `q.size()` and not the top queue's size.
      logging_ops.scalar_summary(
          "bucket/%s/size" % q.name,
          math_ops.cast(q.size(), dtypes.float32))
    logging_ops.scalar_summary(
        "bucket/%s/fraction_of_%d_full" % (top_queue.name, capacity),
        math_ops.cast(top_queue.size(), dtypes.float32) * (1. / capacity))

    # The first dequeued component is the bucket index; the rest are the
    # batched tensors.
    dequeued = top_queue.dequeue(name="dequeue_top")
    which_bucket_dequeued = dequeued[0]
    dequeued = dequeued[1:]
    dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
    return (which_bucket_dequeued, _as_original_type(tensors, dequeued))
|
||||
|
||||
|
||||
def bucket_by_sequence_length(input_length,
                              tensors,
                              batch_size,
                              bucket_boundaries,
                              num_threads=1,
                              capacity=32,
                              shapes=None,
                              dynamic_pad=False,
                              allow_smaller_final_batch=False,
                              keep_input=None,
                              shared_name=None,
                              name=None):
  """Lazy bucketing of inputs according to their length.

  This method calls `tf.contrib.training.bucket` under the hood, after first
  subdividing the bucket boundaries into separate buckets and identifying which
  bucket the given `input_length` belongs to.  See the documentation for
  `bucket` for details of the other arguments.

  Args:
    input_length: `int32` scalar `Tensor`, the sequence length of tensors.
    tensors: The list or dictionary of tensors, representing a single element,
      to bucket.  Nested lists are not supported.
    batch_size: The new batch size pulled from the queue
      (python int or int32 scalar).
    bucket_boundaries: int list, increasing non-negative numbers.
      The edges of the buckets to use when bucketing tensors.  Two extra buckets
      are created, one for `input_length < bucket_boundaries[0]` and
      one for `input_length >= bucket_boundaries[-1]`.
    num_threads: An integer.  The number of threads enqueuing `tensors`.
    capacity: An integer. The maximum number of minibatches in the top queue,
      and also the maximum number of elements within each bucket.
    shapes: (Optional) The shapes for each example.  Defaults to the
      inferred shapes for `tensors`.
    dynamic_pad: Boolean.  Allow variable dimensions in input shapes.
      The given dimensions are padded upon dequeue so that tensors within a
      batch have the same shapes.
    allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
      batches to be smaller if there are insufficient items left in the queues.
    keep_input: (Optional).  A `bool` scalar Tensor.  If provided, this tensor
      controls whether the input is added to the queue or not.  If it evaluates
      `True`, then `tensors` are added to the bucket; otherwise they are
      dropped.  This tensor essentially acts as a filtering mechanism.
      The default behavior is to assume `keep_input=True`.
    shared_name: (Optional). If set, the queues will be shared under the given
      name across multiple sessions.
    name: (Optional) A name for the operations.

  Returns:
    A tuple `(sequence_length, outputs)` where `sequence_length` is
    a 1-D `Tensor` of size `batch_size` and `outputs` is a list or dictionary
    of batched, bucketed, outputs corresponding to elements of `tensors`.

  Raises:
    TypeError: if `bucket_boundaries` is not a list of python integers.
    ValueError: if `bucket_boundaries` is empty or contains non-increasing
      values.
  """
  tensor_list = _as_tensor_list(tensors)
  if not isinstance(bucket_boundaries, (list, tuple)):
    raise TypeError(
        "bucket_boundaries must be a list or tuple, but received: %s"
        % bucket_boundaries)
  if not bucket_boundaries:
    raise ValueError("bucket_boundaries must not be empty")
  # The pairwise loop below never runs for a single boundary, so check its
  # type explicitly; otherwise e.g. `[1.5]` would be accepted silently.
  if len(bucket_boundaries) == 1 and not isinstance(bucket_boundaries[0], int):
    raise TypeError(
        "bucket boundaries must be integers, but saw: %s"
        % (bucket_boundaries[0],))
  for (s, e) in zip(bucket_boundaries[:-1], bucket_boundaries[1:]):
    if not isinstance(s, int) or not isinstance(e, int):
      raise TypeError(
          "bucket boundaries must be integers, but saw: %s and %s" % (s, e))
    if s >= e:
      raise ValueError(
          "Buckets must contain sequential increasing lengths, but saw: "
          "%d before %d" % (s, e))

  with ops.name_scope(name, "bucket_by_sequence_length",
                      [input_length] + tensor_list) as name:
    input_length = ops.convert_to_tensor(
        input_length, dtype=dtypes.int32, name="input_length")
    # Bucketing conditions are:
    #   l < b[0]
    #   b[0] <= l < b[1]
    #   b[1] <= l < b[2]
    #   ...
    #   b[N-2] <= l < b[N-1]
    #   b[N-1] <= l
    # Equivalent to:
    #   [-inf, b[0], b[1], ..., b[N-1]] <= l < [b[0], b[1], ..., b[N-1], inf]
    buckets_min = [np.iinfo(np.int32).min] + list(bucket_boundaries)
    buckets_max = list(bucket_boundaries) + [np.iinfo(np.int32).max]
    conditions_c = math_ops.logical_and(
        math_ops.less_equal(buckets_min, input_length),
        math_ops.less(input_length, buckets_max))
    # Index of the (first) interval that contains `input_length`.
    which_bucket = math_ops.reduce_min(array_ops.where(conditions_c))
    which_bucket = math_ops.to_int32(which_bucket)

    if shapes is not None:
      # `input_length` is prepended to the bucketed tensors below, so prepend
      # its scalar shape too. `list()` lets callers pass a tuple of shapes.
      shapes = [tensor_shape.scalar()] + list(shapes)

    _, dequeued = bucket(
        tensors=[input_length] + tensor_list,
        which_bucket=which_bucket,
        batch_size=batch_size,
        num_buckets=len(bucket_boundaries) + 1,
        num_threads=num_threads,
        capacity=capacity,
        shapes=shapes,
        dynamic_pad=dynamic_pad,
        allow_smaller_final_batch=allow_smaller_final_batch,
        keep_input=keep_input,
        shared_name=shared_name)

    # Split the batched lengths back off from the caller's tensors.
    return (dequeued[0], _as_original_type(tensors, dequeued[1:]))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"bucket",
|
||||
"bucket_by_sequence_length"
|
||||
]
|
356
tensorflow/contrib/training/python/training/bucket_ops_test.py
Normal file
356
tensorflow/contrib/training/python/training/bucket_ops_test.py
Normal file
@ -0,0 +1,356 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for tf.contrib.training.bucket."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def _which_bucket(bucket_edges, v):
|
||||
"""Identify which bucket v falls into.
|
||||
|
||||
Args:
|
||||
bucket_edges: int array, bucket edges
|
||||
v: int scalar, index
|
||||
Returns:
|
||||
int scalar, the bucket.
|
||||
If v < bucket_edges[0], return 0.
|
||||
If bucket_edges[0] <= v < bucket_edges[1], return 1.
|
||||
...
|
||||
If bucket_edges[-2] <= v < bucket_edges[-1], return len(bucket_edges).
|
||||
If v >= bucket_edges[-1], return len(bucket_edges) + 1
|
||||
"""
|
||||
v = np.asarray(v)
|
||||
full = [0] + bucket_edges
|
||||
found = np.where(np.logical_and(v >= full[:-1], v < full[1:]))[0]
|
||||
if not found.size:
|
||||
return len(full)
|
||||
return found[0]
|
||||
|
||||
|
||||
class BucketTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
tf.reset_default_graph()
|
||||
|
||||
self.scalar_int_feed = tf.placeholder(tf.int32, ())
|
||||
self.unk_int64_feed = tf.placeholder(tf.int64, (None,))
|
||||
self.vec3_str_feed = tf.placeholder(tf.string, (3,))
|
||||
|
||||
self._coord = tf.train.Coordinator()
|
||||
# Make capacity very large so we can feed all the inputs in the
|
||||
# main thread without blocking
|
||||
input_queue = tf.PaddingFIFOQueue(
|
||||
5000,
|
||||
dtypes=[tf.int32, tf.int64, tf.string],
|
||||
shapes=[(), (None,), (3,)])
|
||||
|
||||
self._input_enqueue_op = input_queue.enqueue(
|
||||
(self.scalar_int_feed, self.unk_int64_feed, self.vec3_str_feed))
|
||||
self.scalar_int, self.unk_int64, self.vec3_str = input_queue.dequeue()
|
||||
self._threads = None
|
||||
self._close_op = input_queue.close()
|
||||
self._sess = None
|
||||
|
||||
def enqueue_inputs(self, sess, feed_dict):
|
||||
sess.run(self._input_enqueue_op, feed_dict=feed_dict)
|
||||
|
||||
def start_queue_runners(self, sess):
|
||||
# Store session to be able to close inputs later
|
||||
if self._sess is None:
|
||||
self._sess = sess
|
||||
self._threads = tf.train.start_queue_runners(coord=self._coord)
|
||||
|
||||
def tearDown(self):
|
||||
if self._sess is not None:
|
||||
self._sess.run(self._close_op)
|
||||
self._coord.request_stop()
|
||||
self._coord.join(self._threads)
|
||||
|
||||
def testSingleBucket(self):
|
||||
bucketed_dynamic = tf.contrib.training.bucket(
|
||||
tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
|
||||
which_bucket=tf.constant(0),
|
||||
num_buckets=2,
|
||||
batch_size=32,
|
||||
num_threads=10,
|
||||
dynamic_pad=True)
|
||||
# Check shape inference on bucketing outputs
|
||||
self.assertAllEqual(
|
||||
[[32], [32, None], [32, 3]],
|
||||
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
|
||||
with self.test_session() as sess:
|
||||
for v in range(32):
|
||||
self.enqueue_inputs(
|
||||
sess,
|
||||
{self.scalar_int_feed: v,
|
||||
self.unk_int64_feed: v * [v],
|
||||
self.vec3_str_feed: 3 * [str(v)]})
|
||||
self.start_queue_runners(sess)
|
||||
|
||||
# Get a single minibatch
|
||||
bucketed_values = sess.run(bucketed_dynamic)
|
||||
|
||||
# (which_bucket, bucket_tensors).
|
||||
self.assertEqual(2, len(bucketed_values))
|
||||
|
||||
# Count number of bucket_tensors.
|
||||
self.assertEqual(3, len(bucketed_values[1]))
|
||||
|
||||
# Ensure bucket 0 was used for all minibatch entries.
|
||||
self.assertAllEqual(0, bucketed_values[0])
|
||||
|
||||
expected_scalar_int = np.arange(32)
|
||||
expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
|
||||
for i in range(32):
|
||||
expected_unk_int64[i, :i] = i
|
||||
expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
|
||||
|
||||
# Must resort the output because num_threads > 1 leads to
|
||||
# sometimes-inconsistent insertion order.
|
||||
resort = np.argsort(bucketed_values[1][0])
|
||||
self.assertAllEqual(expected_scalar_int, bucketed_values[1][0][resort])
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])
|
||||
|
||||
def testEvenOddBuckets(self):
  # Buckets inputs by parity (bucket 0 = even, bucket 1 = odd) and checks
  # that two minibatches come back, one per bucket.
  which_bucket = (self.scalar_int % 2)
  bucketed_dynamic = tf.contrib.training.bucket(
      tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
      which_bucket=which_bucket,
      num_buckets=2,
      batch_size=32,
      num_threads=10,
      dynamic_pad=True)
  # Static shape inference on the batched outputs.
  inferred_shapes = [
      tensor.get_shape().as_list() for tensor in bucketed_dynamic[1]]
  self.assertAllEqual([[32], [32, None], [32, 3]], inferred_shapes)
  with self.test_session() as sess:
    for v in range(64):
      feed = {self.scalar_int_feed: v,
              self.unk_int64_feed: v * [v],
              self.vec3_str_feed: 3 * [str(v)]}
      self.enqueue_inputs(sess, feed)
    self.start_queue_runners(sess)

    # Get two minibatches (one containing even values, one containing odds)
    bucketed_values_0 = sess.run(bucketed_dynamic)
    bucketed_values_1 = sess.run(bucketed_dynamic)

    # Each result is the pair (which_bucket, bucket_tensors).
    self.assertEqual(2, len(bucketed_values_0))
    self.assertEqual(2, len(bucketed_values_1))

    # One batched tensor per input tensor.
    self.assertEqual(3, len(bucketed_values_0[1]))
    self.assertEqual(3, len(bucketed_values_1[1]))

    # Bucketing is multithreaded, so either bucket may be emitted first;
    # work out which result carries the even values.
    first_is_odd = bucketed_values_0[0] % 2 == 1
    if first_is_odd:
      bucketed_values_even = bucketed_values_1
      bucketed_values_odd = bucketed_values_0
    else:
      bucketed_values_even = bucketed_values_0
      bucketed_values_odd = bucketed_values_1

    # The even batch must report bucket 0 and the odd batch bucket 1.
    self.assertAllEqual(0, bucketed_values_even[0])
    self.assertAllEqual(1, bucketed_values_odd[0])

    def check_batch(batch, first):
      # Expected contents for the 32 values first, first + 2, ...
      want_scalar_int = np.arange(first, 32 * 2 + first, 2)
      # Row i holds value-many copies of its value, zero-padded on the right.
      want_unk_int64 = np.zeros((32, 31 * 2 + first)).astype(np.int64)
      for row in range(0, 32):
        want_unk_int64[row, :2*row + first] = 2*row + first
      want_vec3_str = np.vstack(
          3 * [np.arange(first, 32 * 2 + first, 2).astype(bytes)]).T

      # With num_threads > 1 the insertion order is not always consistent,
      # so sort rows by the scalar column before comparing.
      order = np.argsort(batch[1][0])
      self.assertAllEqual(want_scalar_int, batch[1][0][order])
      self.assertAllEqual(want_unk_int64, batch[1][1][order])
      self.assertAllEqual(want_vec3_str, batch[1][2][order])

    # The evens, starting at 0.
    check_batch(bucketed_values_even, 0)
    # The odds, starting at 1.
    check_batch(bucketed_values_odd, 1)
|
||||
|
||||
def testEvenOddBucketsFilterOutAllOdd(self):
  # Buckets by parity but passes keep_input = (bucket == 0), then checks
  # that only even values show up across two minibatches.
  which_bucket = (self.scalar_int % 2)
  keep_input = tf.equal(which_bucket, 0)
  bucketed_dynamic = tf.contrib.training.bucket(
      tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
      which_bucket=which_bucket,
      num_buckets=2,
      batch_size=32,
      num_threads=10,
      keep_input=keep_input,
      dynamic_pad=True)
  # Static shape inference on the batched outputs.
  inferred_shapes = [
      tensor.get_shape().as_list() for tensor in bucketed_dynamic[1]]
  self.assertAllEqual([[32], [32, None], [32, 3]], inferred_shapes)
  with self.test_session() as sess:
    for v in range(128):
      feed = {self.scalar_int_feed: v,
              self.unk_int64_feed: v * [v],
              self.vec3_str_feed: 3 * [str(v)]}
      self.enqueue_inputs(sess, feed)
    self.start_queue_runners(sess)

    # Get two minibatches ([0, 2, ...] and [64, 66, ...])
    even_batches = [sess.run(bucketed_dynamic) for _ in range(2)]

    # Bucket 1 must have been filtered out completely.
    for batch in even_batches:
      self.assertAllEqual(0, batch[0])

    # Merge the scalar columns so they can be sorted and compared as a whole.
    merged_scalars = np.concatenate(
        (even_batches[0][1][0], even_batches[1][1][0]))

    self.assertAllEqual(np.arange(0, 128, 2), sorted(merged_scalars))
|
||||
|
||||
|
||||
class BucketBySequenceLengthTest(tf.test.TestCase):

  def _testBucketBySequenceLength(self, allow_small_batch):
    # Shared driver: feeds randomly chosen (length, data, targets) triples
    # through bucket_by_sequence_length and validates 50 output batches.
    tf.reset_default_graph()

    # All inputs must be identical lengths across tuple index.
    # The input reader will get input_length from the first tuple
    # entry.
    data_len = 4
    target_len = 3
    input_pairs = []
    for length in (1, 3, 4, 5, 6, 10):
      data_row = [np.int64(length)] * data_len
      target_row = [str(length).encode("ascii")] * target_len
      input_pairs.append((length, (data_row, target_row)))

    lengths = tf.placeholder(tf.int32, ())
    data = tf.placeholder(tf.int64, (data_len,))
    targets = tf.placeholder(tf.string, (target_len,))

    batch_size = 8
    bucket_boundaries = [3, 4, 5, 10]

    # The capacity is deliberately huge so the main thread can enqueue
    # every input without blocking.
    input_queue = tf.FIFOQueue(
        5000, (tf.int32, tf.int64, tf.string),
        ((), (data_len,), (target_len,)))
    input_enqueue_op = input_queue.enqueue((lengths, data, targets))
    lengths_t, data_t, targets_t = input_queue.dequeue()
    close_input_op = input_queue.close()

    bucketed = tf.contrib.training.bucket_by_sequence_length(
        input_length=lengths_t,
        tensors=[data_t, targets_t],
        batch_size=batch_size,
        bucket_boundaries=bucket_boundaries,
        allow_smaller_final_batch=allow_small_batch,
        num_threads=10)
    (out_lengths_t, data_and_targets_t) = bucketed

    # The static batch dimension is unknown when small batches are allowed.
    if allow_small_batch:
      expected_batch_size = None
    else:
      expected_batch_size = batch_size
    self.assertEqual(out_lengths_t.get_shape().as_list(),
                     [expected_batch_size])
    self.assertEqual(data_and_targets_t[0].get_shape().as_list(),
                     [expected_batch_size, data_len])
    self.assertEqual(data_and_targets_t[1].get_shape().as_list(),
                     [expected_batch_size, target_len])

    def _read_test(sess):
      # Reads 50 batches and checks shapes plus per-row consistency.
      for _ in range(50):
        (out_lengths, (batch_data, batch_targets)) = sess.run(
            (out_lengths_t, data_and_targets_t))
        if not allow_small_batch:
          self.assertEqual((batch_size, data_len), batch_data.shape)
          self.assertEqual((batch_size, target_len), batch_targets.shape)
          self.assertEqual((batch_size,), out_lengths.shape)
        else:
          self.assertEqual(data_len, batch_data.shape[1])
          self.assertEqual(target_len, batch_targets.shape[1])
          self.assertGreaterEqual(batch_size, out_lengths.shape[0])
          self.assertGreaterEqual(batch_size, batch_data.shape[0])
          self.assertGreaterEqual(batch_size, batch_targets.shape[0])
        for (lr, dr, tr) in zip(out_lengths, batch_data, batch_targets):
          # The length equals the data value (inputs were built that way).
          self.assertEqual(dr[0], lr)
          # Data and targets must belong to the same input.
          self.assertEqual(dr[0], int(tr[0].decode("ascii")))
          # Within a row, the data must map to a single bucket.
          self.assertEqual(_which_bucket(bucket_boundaries, dr[0]),
                           _which_bucket(bucket_boundaries, dr[1]))

    with self.test_session() as sess:
      coord = tf.train.Coordinator()

      # Enqueue all inputs up front, then close the input queue.
      num_inputs = 50 * batch_size + 100
      for _ in range(num_inputs):
        which = random.randint(0, len(input_pairs) - 1)
        length, pair = input_pairs[which]
        feed = {lengths: length, data: pair[0], targets: pair[1]}
        sess.run(input_enqueue_op, feed_dict=feed)
      sess.run(close_input_op)

      # Start the queue runners
      threads = tf.train.start_queue_runners(coord=coord)
      # Read off the top of the bucket and ensure correctness of output
      _read_test(sess)
      coord.request_stop()
      coord.join(threads)

  def testBucketBySequenceLength(self):
    # Fixed batch size.
    self._testBucketBySequenceLength(allow_small_batch=False)

  def testBucketBySequenceLengthAllow(self):
    # Smaller final batches permitted.
    self._testBucketBySequenceLength(allow_small_batch=True)
|
||||
|
||||
|
||||
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()
|
@ -171,7 +171,6 @@ cc_library(
|
||||
"platform/env.h",
|
||||
"platform/file_system.h",
|
||||
"platform/fingerprint.h",
|
||||
"platform/hexagon/profile_utils/cpu_utils.h",
|
||||
"platform/init_main.h",
|
||||
"platform/logging.h",
|
||||
"platform/macros.h",
|
||||
@ -179,6 +178,7 @@ cc_library(
|
||||
"platform/net.h",
|
||||
"platform/mutex.h",
|
||||
"platform/notification.h",
|
||||
"platform/profile_utils/cpu_utils.h",
|
||||
"platform/protobuf.h", # TODO(josh11b): make internal
|
||||
"platform/regexp.h",
|
||||
"platform/strong_hash.h",
|
||||
@ -862,8 +862,8 @@ cc_library(
|
||||
"lib/**/*.cc",
|
||||
"platform/*.h",
|
||||
"platform/*.cc",
|
||||
"platform/hexagon/**/*.h",
|
||||
"platform/hexagon/**/*.cc",
|
||||
"platform/profile_utils/**/*.h",
|
||||
"platform/profile_utils/**/*.cc",
|
||||
] + tf_additional_lib_srcs(),
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
@ -891,7 +891,7 @@ cc_library(
|
||||
"lib/io/snappy/snappy_inputbuffer.h",
|
||||
"lib/io/snappy/snappy_outputbuffer.h",
|
||||
"lib/io/zlib_compression_options.h",
|
||||
"lib/io/zlib_inputbuffer.h",
|
||||
"lib/io/zlib_inputstream.h",
|
||||
"lib/io/zlib_outputbuffer.h",
|
||||
"lib/jpeg/jpeg_handle.h",
|
||||
"lib/png/png_io.h",
|
||||
@ -1348,11 +1348,11 @@ tf_cc_tests(
|
||||
"lib/strings/stringprintf_test.cc",
|
||||
"lib/wav/wav_io_test.cc",
|
||||
"platform/fingerprint_test.cc",
|
||||
"platform/hexagon/profile_utils/cpu_utils_test.cc",
|
||||
"platform/integral_types_test.cc",
|
||||
"platform/logging_test.cc",
|
||||
"platform/net_test.cc",
|
||||
"platform/port_test.cc",
|
||||
"platform/profile_utils/cpu_utils_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":lib",
|
||||
|
@ -157,7 +157,7 @@ bool BFCAllocator::Extend(size_t rounded_bytes) {
|
||||
InsertFreeChunkIntoBin(h);
|
||||
|
||||
// Invoke visitors on newly allocated region.
|
||||
for (auto visitor : region_visitors_) {
|
||||
for (const auto& visitor : region_visitors_) {
|
||||
visitor(mem_addr, bytes);
|
||||
}
|
||||
return true;
|
||||
|
@ -279,7 +279,7 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
|
||||
edges_to_remove.push_back(out_edge);
|
||||
}
|
||||
}
|
||||
string node_name = n->name();
|
||||
const string& node_name = n->name();
|
||||
Node* constant_node;
|
||||
auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name),
|
||||
"__cf__", UniqueConstantId()),
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -26,7 +27,9 @@ namespace {
|
||||
|
||||
struct RegistrationInfo {
|
||||
RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf)
|
||||
: sender_device_type(s), receiver_device_type(r), copy_function(cf) {}
|
||||
: sender_device_type(std::move(s)),
|
||||
receiver_device_type(r),
|
||||
copy_function(cf) {}
|
||||
DeviceType sender_device_type;
|
||||
DeviceType receiver_device_type;
|
||||
CopyTensor::CopyFunction copy_function;
|
||||
|
@ -71,9 +71,9 @@ std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
|
||||
std::vector<DeviceType> result;
|
||||
std::set<string> seen;
|
||||
for (Device* d : devices_) {
|
||||
auto t = d->device_type();
|
||||
const auto& t = d->device_type();
|
||||
if (seen.insert(t).second) {
|
||||
result.emplace_back(DeviceType(t));
|
||||
result.emplace_back(t);
|
||||
}
|
||||
}
|
||||
std::sort(result.begin(), result.end(), DeviceTypeComparator);
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/common_runtime/session_factory.h"
|
||||
#include "tensorflow/core/common_runtime/simple_placer.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -113,6 +112,77 @@ string GetRendezvousKey(const string& tensor_name,
|
||||
|
||||
} // namespace
|
||||
|
||||
class DirectSessionFactory : public SessionFactory {
|
||||
public:
|
||||
DirectSessionFactory() {}
|
||||
|
||||
bool AcceptsOptions(const SessionOptions& options) override {
|
||||
return options.target.empty();
|
||||
}
|
||||
|
||||
Session* NewSession(const SessionOptions& options) override {
|
||||
// Must do this before the CPU allocator is created.
|
||||
if (options.config.graph_options().build_cost_model() > 0) {
|
||||
EnableCPUAllocatorFullStats(true);
|
||||
}
|
||||
std::vector<Device*> devices;
|
||||
Status s = DeviceFactory::AddDevices(
|
||||
options, "/job:localhost/replica:0/task:0", &devices);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << s;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DirectSession* session =
|
||||
new DirectSession(options, new DeviceMgr(devices), this);
|
||||
{
|
||||
mutex_lock l(sessions_lock_);
|
||||
sessions_.push_back(session);
|
||||
}
|
||||
return session;
|
||||
}
|
||||
|
||||
Status Reset(const SessionOptions& options,
|
||||
const std::vector<string>& containers) override {
|
||||
std::vector<DirectSession*> sessions_to_reset;
|
||||
{
|
||||
mutex_lock l(sessions_lock_);
|
||||
// We create a copy to ensure that we don't have a deadlock when
|
||||
// session->Close calls the DirectSessionFactory.Deregister, which
|
||||
// acquires sessions_lock_.
|
||||
std::swap(sessions_to_reset, sessions_);
|
||||
}
|
||||
Status s;
|
||||
for (auto session : sessions_to_reset) {
|
||||
s.Update(session->Reset(containers));
|
||||
}
|
||||
// TODO(suharshs): Change the Reset behavior of all SessionFactories so that
|
||||
// it doesn't close the sessions?
|
||||
for (auto session : sessions_to_reset) {
|
||||
s.Update(session->Close());
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
void Deregister(const DirectSession* session) {
|
||||
mutex_lock l(sessions_lock_);
|
||||
sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
|
||||
sessions_.end());
|
||||
}
|
||||
|
||||
private:
|
||||
mutex sessions_lock_;
|
||||
std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_);
|
||||
};
|
||||
|
||||
class DirectSessionRegistrar {
|
||||
public:
|
||||
DirectSessionRegistrar() {
|
||||
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
|
||||
}
|
||||
};
|
||||
static DirectSessionRegistrar registrar;
|
||||
|
||||
std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
|
||||
|
||||
// NOTE: On Android with a single device, there is never
|
||||
@ -146,10 +216,13 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
|
||||
}
|
||||
|
||||
DirectSession::DirectSession(const SessionOptions& options,
|
||||
const DeviceMgr* device_mgr)
|
||||
const DeviceMgr* device_mgr,
|
||||
DirectSessionFactory* const factory)
|
||||
: options_(options),
|
||||
device_mgr_(device_mgr),
|
||||
factory_(factory),
|
||||
cancellation_manager_(new CancellationManager()),
|
||||
closed_(false),
|
||||
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
|
||||
if (options_.config.session_inter_op_thread_pool_size() > 0) {
|
||||
for (int i = 0; i < options_.config.session_inter_op_thread_pool_size();
|
||||
@ -194,6 +267,7 @@ DirectSession::DirectSession(const SessionOptions& options,
|
||||
}
|
||||
|
||||
DirectSession::~DirectSession() {
|
||||
if (!closed_) Close();
|
||||
for (auto& it : partial_runs_) {
|
||||
it.second.reset(nullptr);
|
||||
}
|
||||
@ -237,6 +311,7 @@ Status DirectSession::Create(const GraphDef& graph) {
|
||||
}
|
||||
|
||||
// Extends the session's graph with `graph`. Returns the CheckNotClosed()
// error without touching the graph if the session is already closed.
Status DirectSession::Extend(const GraphDef& graph) {
  const Status not_closed = CheckNotClosed();
  if (!not_closed.ok()) {
    return not_closed;
  }
  // ExtendLocked() is called with graph_def_lock_ held.
  mutex_lock l(graph_def_lock_);
  return ExtendLocked(graph);
}
|
||||
@ -267,6 +342,7 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
const std::vector<string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
RunMetadata* run_metadata) {
|
||||
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||
direct_session_runs->GetCell()->IncrementBy(1);
|
||||
{
|
||||
mutex_lock l(graph_def_lock_);
|
||||
@ -412,6 +488,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
|
||||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_nodes,
|
||||
string* handle) {
|
||||
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||
{
|
||||
mutex_lock l(graph_def_lock_);
|
||||
if (!graph_created_) {
|
||||
@ -487,6 +564,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
|
||||
Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
std::vector<Tensor>* outputs) {
|
||||
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||
std::vector<string> parts = str_util::Split(handle, ';');
|
||||
const string& key = parts[0];
|
||||
// Get the executors for this partial run.
|
||||
@ -1002,8 +1080,20 @@ Status DirectSession::CreateGraphs(
|
||||
return s;
|
||||
}
|
||||
|
||||
// Forwards `containers` to the device manager's ClearContainers() and
// always reports OK.
::tensorflow::Status DirectSession::Reset(
    const std::vector<string>& containers) {
  const DeviceMgr* const manager = device_mgr_;
  manager->ClearContainers(containers);
  return ::tensorflow::Status::OK();
}
|
||||
|
||||
// Starts cancellation, marks the session closed and, the first time only,
// deregisters the session from its factory (if it has one). Always returns
// OK; repeated calls skip the deregistration.
::tensorflow::Status DirectSession::Close() {
  cancellation_manager_->StartCancel();
  bool was_already_closed = false;
  {
    mutex_lock l(mu_);
    was_already_closed = closed_;
    closed_ = true;
  }
  // Deregister outside of mu_, and only on the transition to closed.
  if (!was_already_closed && factory_ != nullptr) {
    factory_->Deregister(this);
  }
  return ::tensorflow::Status::OK();
}
|
||||
|
||||
@ -1051,37 +1141,4 @@ void DirectSession::WaitForNotification(RunState* run_state,
|
||||
}
|
||||
}
|
||||
|
||||
class DirectSessionFactory : public SessionFactory {
|
||||
public:
|
||||
DirectSessionFactory() {}
|
||||
|
||||
bool AcceptsOptions(const SessionOptions& options) override {
|
||||
return options.target.empty();
|
||||
}
|
||||
|
||||
Session* NewSession(const SessionOptions& options) override {
|
||||
// Must do this before the CPU allocator is created.
|
||||
if (options.config.graph_options().build_cost_model() > 0) {
|
||||
EnableCPUAllocatorFullStats(true);
|
||||
}
|
||||
std::vector<Device*> devices;
|
||||
Status s = DeviceFactory::AddDevices(
|
||||
options, "/job:localhost/replica:0/task:0", &devices);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << s;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new DirectSession(options, new DeviceMgr(devices));
|
||||
}
|
||||
};
|
||||
|
||||
class DirectSessionRegistrar {
|
||||
public:
|
||||
DirectSessionRegistrar() {
|
||||
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
|
||||
}
|
||||
};
|
||||
static DirectSessionRegistrar registrar;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/session_factory.h"
|
||||
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
|
||||
#include "tensorflow/core/debug/debug_graph_utils.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
@ -47,11 +48,18 @@ namespace tensorflow {
|
||||
class CostModel;
|
||||
class DebugGateway;
|
||||
class Device;
|
||||
class DirectSessionFactory;
|
||||
|
||||
class DirectSession : public Session {
|
||||
public:
|
||||
typedef std::function<void(Session*)> CloseCallback;
|
||||
|
||||
// Takes ownership of 'device_mgr'.
|
||||
DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr);
|
||||
// 'factory' is used to unregister the DirectSession with 'factory' when its
|
||||
// closed. This ensures that Reset requests from the 'factory' don't get sent
|
||||
// to sessions that are already closed.
|
||||
DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
|
||||
DirectSessionFactory* factory);
|
||||
~DirectSession() override;
|
||||
|
||||
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
|
||||
@ -83,6 +91,10 @@ class DirectSession : public Session {
|
||||
const std::vector<string>& output_names,
|
||||
std::vector<Tensor>* outputs) override;
|
||||
|
||||
// Reset clears 'containers' from the device_mgr of the DirectSession.
|
||||
// If 'containers' is empty, then Reset clears the default container.
|
||||
::tensorflow::Status Reset(const std::vector<string>& containers);
|
||||
|
||||
::tensorflow::Status Close() override;
|
||||
|
||||
void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
|
||||
@ -198,6 +210,12 @@ class DirectSession : public Session {
|
||||
// operation_timeout_in_ms is greater than 0.
|
||||
void WaitForNotification(RunState* run_state, int64 timeout_in_ms);
|
||||
|
||||
// Returns a Cancelled error once the session has been closed, OK otherwise.
// Reads closed_ under mu_.
::tensorflow::Status CheckNotClosed() {
  mutex_lock l(mu_);
  if (!closed_) {
    return ::tensorflow::Status::OK();
  }
  return errors::Cancelled("Session has been closed.");
}
|
||||
|
||||
const SessionOptions options_;
|
||||
|
||||
// Device structures.
|
||||
@ -232,10 +250,12 @@ class DirectSession : public Session {
|
||||
// This holds all the tensors that are currently alive in the session.
|
||||
SessionState session_state_;
|
||||
|
||||
DirectSessionFactory* const factory_; // not owned
|
||||
CancellationManager* cancellation_manager_;
|
||||
|
||||
// Saves and restores device placements for stateful nodes.
|
||||
mutex mu_;
|
||||
|
||||
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
|
||||
// is true, such as "params" and "queue" nodes. Once placed these
|
||||
// nodes can not be moved to a different device. Maps node names to
|
||||
@ -251,6 +271,9 @@ class DirectSession : public Session {
|
||||
// library; it copies and modifies the function library.
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
||||
|
||||
// true if the Session has been Closed.
|
||||
bool closed_ GUARDED_BY(mu_);
|
||||
|
||||
// For generating unique names.
|
||||
int64 name_counter_ GUARDED_BY(mu_) = 0;
|
||||
|
||||
|
@ -397,6 +397,14 @@ TEST(DirectSessionTest, MultipleFeedTest) {
|
||||
ASSERT_EQ(2, outputs.size());
|
||||
ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
|
||||
ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
|
||||
|
||||
// Feed [first_const, first_const]
|
||||
s = session->Run(
|
||||
{{first_const->name(), value_11}, {first_const->name(), value_22}},
|
||||
{first_identity->name() + ":0", second_identity->name() + ":0"}, {},
|
||||
&outputs);
|
||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||
EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once"));
|
||||
}
|
||||
|
||||
REGISTER_OP("Darth")
|
||||
@ -970,5 +978,129 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that Run() on a session fails with a Cancelled error after the
// session has been closed.
TEST(DirectSessionTest, TestDirectSessionRunClose) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2};
  Node* var_val = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, var_val);
  GraphDef def;
  test::graph::ToGraphDef(&g, &def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Assign a value to the var.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Run a read on the variable to ensure that it works.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  // Guard the index below: exactly one fetch was requested.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Close the session; a failing Close() would invalidate the check below.
  TF_ASSERT_OK(session->Close());

  // Run the assign again; the closed session must reject it.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
|
||||
|
||||
// Verifies that PRun() on a previously set-up partial run fails with a
// Cancelled error after the session has been closed.
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  test::graph::ToGraphDef(&g, &def);

  std::unique_ptr<Session> session(CreateSession());
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Set up the partial run while the session is still open.
  string handle;
  Status s = session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;

  // Close the session; a failing Close() would invalidate the check below.
  TF_ASSERT_OK(session->Close());

  // Feed first_const, fetch first_identity: must be rejected after Close().
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {first_identity->name() + ":0"}, &outputs);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
|
||||
|
||||
// Verifies that Reset() leaves existing sessions closed: a later Run() on
// such a session fails with a Cancelled error.
TEST(DirectSessionTest, TestDirectSessionReset) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2};
  Node* var_val = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, var_val);
  GraphDef def;
  test::graph::ToGraphDef(&g, &def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Assign a value to the var.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Run a read on the variable to ensure that it works.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  // Guard the index below: exactly one fetch was requested.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Reset the containers; a failing Reset() would invalidate the check below.
  TF_ASSERT_OK(Reset(options, {}));

  // Run the assign again to get an error.
  // TODO(suharshs): This test only works because we close the Session in Reset.
  // If we change the behavior of Reset to not close the Session, this test will
  // fail, since the Variable buffer is cached by var.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -144,7 +144,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
|
||||
void Init(const std::vector<FunctionDef>& flib) {
|
||||
FunctionDefLibrary proto;
|
||||
for (auto fdef : flib) *(proto.add_function()) = fdef;
|
||||
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
||||
delete lib_def_;
|
||||
lib_def_ = new FunctionLibraryDefinition(OpRegistry::Global(), proto);
|
||||
delete lib_;
|
||||
|
@ -95,7 +95,7 @@ void EventMgr::ThenDeleteTensors(perftools::gputools::Stream* stream,
|
||||
FlushAccumulatedTensors();
|
||||
}
|
||||
accumulated_stream_ = stream;
|
||||
for (auto t : tensors) {
|
||||
for (const auto& t : tensors) {
|
||||
// accumulated_tensors_ takes over ownership of the reference to "t"
|
||||
accumulated_tensors_->push_back(t);
|
||||
accumulated_tensor_bytes_ += t.TotalBytes();
|
||||
|
@ -129,7 +129,7 @@ TEST_F(GpuStreamUtilTest, StreamOverrides) {
|
||||
// Nodes should be assigned to streams by op type.
|
||||
for (const auto& it : node_to_stream_id) {
|
||||
Node* n = g.FindNodeId(it.first);
|
||||
const string op = n->type_string();
|
||||
const string& op = n->type_string();
|
||||
const int stream = it.second;
|
||||
if (op == "Const") {
|
||||
EXPECT_EQ(stream, 90);
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <sys/mman.h> // for munmap
|
||||
|
||||
#include <map>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -31,7 +32,7 @@ namespace tensorflow {
|
||||
PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
|
||||
SubAllocator* allocator,
|
||||
RoundUpInterface* size_rounder, string name)
|
||||
: name_(name),
|
||||
: name_(std::move(name)),
|
||||
has_size_limit_(pool_size_limit > 0),
|
||||
auto_resize_(auto_resize),
|
||||
pool_size_limit_(pool_size_limit),
|
||||
@ -125,7 +126,7 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
|
||||
return PrepareChunk(r, alignment, num_bytes);
|
||||
} else {
|
||||
void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
|
||||
for (auto v : alloc_visitors_) {
|
||||
for (const auto& v : alloc_visitors_) {
|
||||
v(ptr, num_bytes);
|
||||
}
|
||||
return PrepareChunk(ptr, alignment, num_bytes);
|
||||
@ -137,7 +138,7 @@ void PoolAllocator::DeallocateRaw(void* ptr) {
|
||||
ChunkPrefix* cp = FindPrefix(ptr);
|
||||
CHECK_LE((void*)cp, (void*)ptr);
|
||||
if (!has_size_limit_ && !auto_resize_) {
|
||||
for (auto v : free_visitors_) {
|
||||
for (const auto& v : free_visitors_) {
|
||||
v(cp, cp->num_bytes);
|
||||
}
|
||||
allocator_->Free(cp, cp->num_bytes);
|
||||
@ -160,7 +161,7 @@ void PoolAllocator::Clear() {
|
||||
mutex_lock lock(mutex_);
|
||||
for (auto iter : pool_) {
|
||||
PtrRecord* pr = iter.second;
|
||||
for (auto v : free_visitors_) {
|
||||
for (const auto& v : free_visitors_) {
|
||||
v(pr->ptr, pr->num_bytes);
|
||||
}
|
||||
allocator_->Free(pr->ptr, pr->num_bytes);
|
||||
@ -217,7 +218,7 @@ void PoolAllocator::EvictOne() {
|
||||
DCHECK(iter != pool_.end());
|
||||
}
|
||||
pool_.erase(iter);
|
||||
for (auto v : free_visitors_) {
|
||||
for (const auto& v : free_visitors_) {
|
||||
v(prec->ptr, prec->num_bytes);
|
||||
}
|
||||
allocator_->Free(prec->ptr, prec->num_bytes);
|
||||
|
@ -181,12 +181,25 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
|
||||
// different numa_nodes. For now, just one.
|
||||
numa_node = 0;
|
||||
mutex_lock lock(mu_);
|
||||
|
||||
// Find the first valid StreamExecutor to request CUDA host memory
|
||||
// through, since any will work.
|
||||
//
|
||||
// This search isn't super clean, and it would be nice to use a
|
||||
// better source of information about which executor to use. For
|
||||
// example, process_state could maybe save the first stream executor
|
||||
// it knows is valid.
|
||||
gpu::StreamExecutor* se = nullptr;
|
||||
for (size_t i = 0; i < gpu_allocators_.size(); ++i) {
|
||||
if (gpu_allocators_[i] != nullptr) {
|
||||
se = GPUMachineManager()->ExecutorForDevice(i).ValueOrDie();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_NE(nullptr, se);
|
||||
|
||||
while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
|
||||
// CUDAHost alloc the same across all gpus, so just get the
|
||||
// executor for the first device.
|
||||
gpu::Platform* gpu_platform = GPUMachineManager();
|
||||
gpu::StreamExecutor* se = gpu_platform->ExecutorForDevice(0).ValueOrDie();
|
||||
CHECK(se);
|
||||
Allocator* allocator = nullptr;
|
||||
static constexpr bool kCudaHostMemoryUseBFC = true;
|
||||
if (kCudaHostMemoryUseBFC) {
|
||||
|
@ -44,6 +44,7 @@ SimpleGraphExecutionState::SimpleGraphExecutionState(
|
||||
const SimpleGraphExecutionStateOptions& options)
|
||||
: device_set_(options.device_set),
|
||||
session_options_(options.session_options),
|
||||
costs_(true /*is_global*/),
|
||||
flib_def_(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), func_def_lib)),
|
||||
graph_(nullptr) {
|
||||
@ -53,6 +54,7 @@ SimpleGraphExecutionState::SimpleGraphExecutionState(
|
||||
|
||||
// Tears down the execution state. While holding mu_, the name-to-cost-id
// index is emptied and the graph pointed to by graph_ is deleted.
SimpleGraphExecutionState::~SimpleGraphExecutionState() {
  mutex_lock lock(mu_);
  node_name_to_cost_id_map_.clear();
  delete graph_;
}
|
||||
|
||||
@ -178,6 +180,10 @@ Status SimpleGraphExecutionState::InitBaseGraph(
|
||||
GraphConstructorOptions opts;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToGraph(opts, original_graph_def_, new_graph.get()));
|
||||
for (const Node* n : new_graph->nodes()) {
|
||||
VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
|
||||
node_name_to_cost_id_map_[n->name()] = n->cost_id();
|
||||
}
|
||||
if (session_options_ &&
|
||||
session_options_->config.graph_options().place_pruned_graph()) {
|
||||
// Rewrite the graph before placement.
|
||||
@ -189,10 +195,15 @@ Status SimpleGraphExecutionState::InitBaseGraph(
|
||||
// Save stateful placements before placing.
|
||||
RestoreStatefulNodes(new_graph.get());
|
||||
|
||||
CostModel costs(true /*is_global*/);
|
||||
costs_.InitFromGraph(*new_graph.get());
|
||||
costs.MergeFromGlobal(costs_);
|
||||
|
||||
GraphOptimizationPassOptions optimization_options;
|
||||
optimization_options.session_options = session_options_;
|
||||
optimization_options.graph = &new_graph;
|
||||
optimization_options.flib_def = flib_def_.get();
|
||||
optimization_options.cost_model = &costs;
|
||||
|
||||
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
||||
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
|
||||
@ -209,6 +220,31 @@ Status SimpleGraphExecutionState::InitBaseGraph(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Merges the execution statistics in "ss" into costs_, handing the cost model
// the node-name-to-cost-id map alongside the stats.
void SimpleGraphExecutionState::UpdateCostsFromStats(const StepStats& ss) {
  // costs_ is GUARDED_BY(mu_), so the merge runs under the lock.
  mutex_lock lock(mu_);
  costs_.MergeFromStats(node_name_to_cost_id_map_, ss);
}
|
||||
|
||||
// Merges the cost model maintained by this execution state (costs_) into the
// caller-supplied "costs". "costs" must not be null: it is dereferenced
// unconditionally.
void SimpleGraphExecutionState::MergeCostsFromGlobal(CostModel* costs) {
  // costs_ is GUARDED_BY(mu_), so it is only read while the lock is held.
  mutex_lock lock(mu_);
  costs->MergeFromGlobal(costs_);
}
|
||||
|
||||
// Looks up the node called "name" in graph_ and copies its NodeDef into *out.
//
// The name is first translated to a cost id through
// node_name_to_cost_id_map_; that id is then passed to Graph::FindNodeId().
// Returns NotFound if the name is not in the map or if FindNodeId() yields no
// node for the id. *out is written only on success.
Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name,
                                                      NodeDef* out) {
  const auto entry = node_name_to_cost_id_map_.find(name);
  if (entry == node_name_to_cost_id_map_.end()) {
    return errors::NotFound("Node name: ", name);
  }

  // The graph itself is only consulted while mu_ is held.
  mutex_lock lock(mu_);  // could use reader lock
  const Node* found = graph_->FindNodeId(entry->second);
  if (found == nullptr) {
    return errors::NotFound("Node name: ", name);
  }
  *out = found->def();
  return Status::OK();
}
|
||||
|
||||
Status SimpleGraphExecutionState::BuildGraph(
|
||||
const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
|
||||
VLOG(1) << "BuildGraph";
|
||||
@ -234,10 +270,14 @@ Status SimpleGraphExecutionState::BuildGraph(
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib(
|
||||
new FunctionLibraryDefinition(*flib_def_));
|
||||
|
||||
// TODO(andydavis): Clarify optimization pass requirements around CostModel.
|
||||
CostModel costs(true /*is_global*/);
|
||||
costs.MergeFromGlobal(costs_);
|
||||
GraphOptimizationPassOptions optimization_options;
|
||||
optimization_options.session_options = session_options_;
|
||||
optimization_options.graph = &ng;
|
||||
optimization_options.flib_def = flib.get();
|
||||
optimization_options.cost_model = &costs;
|
||||
|
||||
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
||||
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
|
||||
|
@ -119,6 +119,22 @@ class SimpleGraphExecutionState {
|
||||
// execution, e.g. a send, recv or feed node.
|
||||
Status GlobalNodeDefByName(const string& name, NodeDef* out);
|
||||
|
||||
// Sums execution statistics in "ss" into the CostModel.
|
||||
void UpdateCostsFromStats(const StepStats& ss);
|
||||
|
||||
// Returns costs_.TimeEstimate(n), evaluated while holding mu_.
Microseconds TimeEstimate(const Node* n) {
  mutex_lock lock(mu_);  // could use reader lock
  const Microseconds estimate = costs_.TimeEstimate(n);
  return estimate;
}
|
||||
|
||||
// Returns costs_.SizeEstimate(n, output_slot), evaluated while holding mu_.
Bytes SizeEstimate(const Node* n, int output_slot) {
  mutex_lock lock(mu_);  // could use reader lock
  const Bytes estimate = costs_.SizeEstimate(n, output_slot);
  return estimate;
}
|
||||
|
||||
// Merge the cost model maintained by this graph_execution_state to 'costs'.
|
||||
void MergeCostsFromGlobal(CostModel* costs);
|
||||
|
||||
// The graph returned by BuildGraph may contain only the pruned
|
||||
// graph, whereas some clients may want access to the full graph.
|
||||
const Graph* full_graph() {
|
||||
@ -162,6 +178,11 @@ class SimpleGraphExecutionState {
|
||||
const DeviceSet* device_set_; // Not owned
|
||||
const SessionOptions* session_options_; // Not owned
|
||||
|
||||
CostModel costs_ GUARDED_BY(mu_);
|
||||
|
||||
// Map from node name to Node::cost_id() for the nodes of the full graph.
|
||||
NodeNameToCostIdMap node_name_to_cost_id_map_;
|
||||
|
||||
// 'flib_def_' is initialized from the initial graph def's library,
|
||||
// and may be updated by a graph optimization pass.
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
||||
|
@ -42,7 +42,7 @@ std::vector<Device*> FilterSupportedDevices(
|
||||
const std::vector<Device*>& devices,
|
||||
const DeviceTypeVector& supported_device_types) {
|
||||
std::vector<Device*> filtered_devices;
|
||||
for (DeviceType d : supported_device_types) {
|
||||
for (const DeviceType& d : supported_device_types) {
|
||||
for (Device* device : devices) {
|
||||
if (DeviceType(device->attributes().device_type()) == d) {
|
||||
filtered_devices.emplace_back(device);
|
||||
@ -238,11 +238,15 @@ class ColocationGraph {
|
||||
// members_[old_root].supported_device_types.
|
||||
MergeSupportedDevices(&members_[new_root].supported_device_types,
|
||||
members_[old_root].supported_device_types);
|
||||
if (members_[x_root].supported_device_types.size() == 0) {
|
||||
if (members_[new_root].supported_device_types.size() == 0) {
|
||||
string debug_info;
|
||||
AddDebugInfo(x_root, &debug_info);
|
||||
AddDebugInfo(y_root, &debug_info);
|
||||
return errors::InvalidArgument(
|
||||
"Cannot colocate nodes '", x.name(), "' and '", y.name(),
|
||||
"' because no device type supports both of those nodes and the "
|
||||
"other nodes colocated with them");
|
||||
"other nodes colocated with them.",
|
||||
debug_info);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -495,7 +499,7 @@ class ColocationGraph {
|
||||
"' does not match any device");
|
||||
}
|
||||
|
||||
for (DeviceType d : member->supported_device_types) {
|
||||
for (const DeviceType& d : member->supported_device_types) {
|
||||
if (DeviceType(assigned_device->attributes().device_type()) == d) {
|
||||
return Status::OK();
|
||||
}
|
||||
@ -545,9 +549,9 @@ class ColocationGraph {
|
||||
target->clear();
|
||||
|
||||
// Iterate in priority order.
|
||||
for (DeviceType device_type : temp) {
|
||||
for (const DeviceType& device_type : temp) {
|
||||
bool found = false;
|
||||
for (DeviceType other_device_type : other) {
|
||||
for (const DeviceType& other_device_type : other) {
|
||||
if (device_type == other_device_type) {
|
||||
found = true;
|
||||
break;
|
||||
|
@ -689,8 +689,9 @@ TEST_F(SimplePlacerTest,
|
||||
Status s = Place(&g);
|
||||
EXPECT_TRUE(
|
||||
StringPiece(s.error_message())
|
||||
.contains("Cannot assign a device to node 'var3': Node had no "
|
||||
"OpKernel registered"));
|
||||
.contains("Cannot colocate nodes 'var3' and 'assign3' because no "
|
||||
"device type supports both of those nodes and the other "
|
||||
"nodes colocated with them."));
|
||||
}
|
||||
|
||||
TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
|
||||
|
@ -54,9 +54,9 @@ namespace tensorflow {
|
||||
namespace {
|
||||
// A little bit of per-step state.
|
||||
struct PerStepState {
|
||||
bool collect_timeline;
|
||||
Microseconds start_micros = Microseconds(0);
|
||||
Microseconds end_micros = Microseconds(0);
|
||||
std::vector<StepStats> step_stats; // per partition
|
||||
};
|
||||
|
||||
// A session encapsulates a graph computation (resource allocation,
|
||||
@ -522,6 +522,10 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
|
||||
// Prepares a number of calls to workers. One call per partition.
|
||||
ExecutorOpts exec_opts;
|
||||
if (pss->collect_timeline) {
|
||||
exec_opts.set_record_timeline(true);
|
||||
}
|
||||
|
||||
const int num = partitions_.size();
|
||||
RunManyGraphs calls(num);
|
||||
|
||||
@ -597,8 +601,9 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (calls.get(i)->resp.has_step_stats()) {
|
||||
pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats());
|
||||
if (pss->collect_timeline && calls.get(i)->resp.has_step_stats()) {
|
||||
resp->mutable_metadata()->mutable_step_stats()->MergeFrom(
|
||||
calls.get(i)->resp.step_stats());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -953,6 +958,8 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
||||
TRACEPRINTF("stepid %llu", step_id);
|
||||
|
||||
pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
|
||||
|
||||
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
|
||||
execution_state_.get(), &pss, opts,
|
||||
*req, resp, cancellation_manager_));
|
||||
|
@ -162,6 +162,8 @@ Status GrpcSession::Run(const RunOptions& run_options,
|
||||
RunStepRequest req;
|
||||
RunStepResponse resp;
|
||||
|
||||
*req.mutable_options() = run_options;
|
||||
|
||||
for (const auto& it : inputs) {
|
||||
Tensor input_tensor = it.second;
|
||||
auto feed = req.add_feed();
|
||||
@ -206,6 +208,10 @@ Status GrpcSession::Run(const RunOptions& run_options,
|
||||
(*outputs)[fetch_it->second] = output;
|
||||
}
|
||||
|
||||
if (run_metadata) {
|
||||
run_metadata->Swap(resp.mutable_metadata());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -75,6 +75,9 @@ static SessionOptions Options(const string& target, int placement_period) {
|
||||
// string.
|
||||
options.target = strings::StrCat("grpc://", target);
|
||||
options.config.set_placement_period(placement_period);
|
||||
options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_opt_level(OptimizerOptions::L0);
|
||||
return options;
|
||||
}
|
||||
|
||||
@ -307,9 +310,29 @@ TEST(GrpcSessionTest, MultiDevices) {
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
{
|
||||
std::vector<Tensor> outputs;
|
||||
TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
|
||||
RunOptions options;
|
||||
options.set_trace_level(RunOptions::FULL_TRACE);
|
||||
RunMetadata metadata;
|
||||
TF_CHECK_OK(
|
||||
session->Run(options, {}, {c->name()}, {}, &outputs, &metadata));
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
IsSingleFloatValue(outputs[0], 6.0 * kSize);
|
||||
|
||||
const StepStats& ss = metadata.step_stats();
|
||||
// NOTE(mrry): We only assert that `c` is placed correctly,
|
||||
// because the current placement algorithm will move its
|
||||
// inputs to be colocated with it, when it is the sole
|
||||
// consumer.
|
||||
bool c_placed_correctly = false;
|
||||
for (const auto& dev : ss.dev_stats()) {
|
||||
for (const auto& node : dev.node_stats()) {
|
||||
if (node.node_name() == c->name() &&
|
||||
dev.device() == c_dev.name()) {
|
||||
c_placed_correctly = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(c_placed_correctly);
|
||||
}
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
|
@ -325,7 +325,10 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
return;
|
||||
}
|
||||
StepStatsCollector* collector = nullptr;
|
||||
// TODO(mrry): Collect results from a profiler if available.
|
||||
if (call->request.exec_opts().record_timeline()) {
|
||||
collector = new StepStatsCollector(call->response.mutable_step_stats());
|
||||
// TODO(mrry,pbar): GPU tracing for distributed steps.
|
||||
}
|
||||
CancellationManager* cm = new CancellationManager;
|
||||
call->SetCancelCallback([this, cm, step_id]() {
|
||||
cm->StartCancel();
|
||||
@ -340,7 +343,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
}
|
||||
env_->graph_mgr->ExecuteAsync(
|
||||
call->request.graph_handle(), step_id, call->request.exec_opts(),
|
||||
collector, cm, in, out, [this, call, cm, out, token](Status s) {
|
||||
collector, cm, in, out,
|
||||
[this, call, cm, out, token, collector](Status s) {
|
||||
call->ClearCancelCallback();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
@ -359,6 +363,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
val.AsProtoField(proto);
|
||||
}
|
||||
}
|
||||
delete collector;
|
||||
delete out;
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
|
@ -39,6 +39,9 @@ message CostGraphDef {
|
||||
// Temporary memory used by this node.
|
||||
int64 temporary_memory_size = 6;
|
||||
|
||||
// Estimate of the computational cost of this node.
|
||||
int64 compute_cost = 9;
|
||||
|
||||
// If true, the output is permanent: it can't be discarded, because this
|
||||
// node is part of the "final output". Nodes may depend on final nodes.
|
||||
bool is_final = 7;
|
||||
|
@ -861,11 +861,11 @@ string DebugString(const GraphDef& instantiated_func_def) {
|
||||
|
||||
string DebugStringWhole(const GraphDef& gdef) {
|
||||
string ret;
|
||||
for (auto fdef : gdef.library().function()) {
|
||||
for (const auto& fdef : gdef.library().function()) {
|
||||
strings::StrAppend(&ret, Print(fdef));
|
||||
}
|
||||
strings::StrAppend(&ret, "\n");
|
||||
for (auto ndef : gdef.node()) {
|
||||
for (const auto& ndef : gdef.node()) {
|
||||
strings::StrAppend(&ret, Print(ndef), "\n");
|
||||
}
|
||||
return ret;
|
||||
|
@ -33,7 +33,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
class CancellationManager;
|
||||
class Node;
|
||||
class OpKernel;
|
||||
class ResourceMgr;
|
||||
|
||||
|
@ -31,11 +31,11 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
|
||||
VersionDef* versions = g.mutable_versions();
|
||||
versions->set_producer(TF_GRAPH_DEF_VERSION);
|
||||
versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
|
||||
for (auto n : nodes) {
|
||||
for (const auto& n : nodes) {
|
||||
*(g.add_node()) = n;
|
||||
}
|
||||
auto lib = g.mutable_library();
|
||||
for (auto f : funcs) {
|
||||
for (const auto& f : funcs) {
|
||||
*(lib->add_function()) = f;
|
||||
}
|
||||
return g;
|
||||
@ -49,7 +49,7 @@ NodeDef NDef(const string& name, const string& op,
|
||||
NodeDef n;
|
||||
n.set_name(name);
|
||||
n.set_op(op);
|
||||
for (auto in : inputs) n.add_input(in);
|
||||
for (const auto& in : inputs) n.add_input(in);
|
||||
n.set_device(device);
|
||||
for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
|
||||
return n;
|
||||
|
@ -60,7 +60,7 @@ Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
|
||||
|
||||
Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
|
||||
const AttrValue& allowed_values(attr.allowed_values());
|
||||
for (auto allowed : allowed_values.list().s()) {
|
||||
for (const auto& allowed : allowed_values.list().s()) {
|
||||
if (str == allowed) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -381,7 +381,7 @@ class OpKernelBuilderTest : public ::testing::Test {
|
||||
DeviceTypeVector devices;
|
||||
TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
|
||||
bool found = false;
|
||||
for (DeviceType dt : devices) {
|
||||
for (const DeviceType& dt : devices) {
|
||||
if (dt == device_type) {
|
||||
found = true;
|
||||
}
|
||||
@ -414,7 +414,7 @@ class OpKernelBuilderTest : public ::testing::Test {
|
||||
DeviceTypeVector devices;
|
||||
if (errors::IsNotFound(status)) {
|
||||
TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
|
||||
for (DeviceType dt : devices) {
|
||||
for (const DeviceType& dt : devices) {
|
||||
EXPECT_NE(dt, device_type);
|
||||
}
|
||||
} else {
|
||||
|
@ -200,7 +200,7 @@ class TensorShape {
|
||||
DataType data_type() const { return static_cast<DataType>(buf()[13]); }
|
||||
void set_data_type(DataType dt) {
|
||||
// We only have 8 bits available to store DataType, so make sure it fits
|
||||
DCHECK_LT(static_cast<uint32>(dt), 256);
|
||||
DCHECK_LT(static_cast<uint32>(dt), 256u);
|
||||
buf()[13] = static_cast<uint8>(dt);
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,9 @@ class TensorSlice {
|
||||
}
|
||||
|
||||
// If we have a full slice along dimension "d".
|
||||
bool IsFullAt(int d) const { return lengths_[d] < 0; }
|
||||
bool IsFullAt(int d) const {
|
||||
return lengths_[d] == kFullExtent && starts_[d] == 0;
|
||||
}
|
||||
|
||||
// If this is a full slice, i.e. IsFullAt(d) for every d.
|
||||
bool IsFull() const;
|
||||
|
@ -273,8 +273,8 @@ TEST(TensorSliceTest, Deserialization) {
|
||||
TensorSlice ts3(proto3);
|
||||
|
||||
// Both serializations should be interpreted the same.
|
||||
EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString());
|
||||
EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString());
|
||||
EXPECT_EQ("0,5:0,10:14,1:1,-1:-", ts2.DebugString());
|
||||
EXPECT_EQ("0,5:0,10:14,1:1,-1:-", ts3.DebugString());
|
||||
}
|
||||
|
||||
TEST(TensorSliceTest, UpdateToCover) {
|
||||
|
@ -326,7 +326,7 @@ TEST_F(OptimizerCSETest, Constant_Dedup) {
|
||||
|
||||
// A graph contains a bunch of constants.
|
||||
Graph g(OpRegistry::Global());
|
||||
for (auto val : {a, b, c, d, d, c, b, a}) {
|
||||
for (const auto& val : {a, b, c, d, d, c, b, a}) {
|
||||
test::graph::Constant(&g, val); // Node name is n/_0, n/_1, ...
|
||||
}
|
||||
GraphDef gdef;
|
||||
|
@ -74,7 +74,7 @@ inline bool IsGradientNode(const Graph* graph, const Node* node) {
|
||||
// Returns true if the root tensor op type is known, false otherwise.
|
||||
bool FindType(const Graph* graph, const Node* node, bool* signed_input,
|
||||
bool* range_given, float* input_min, float* input_max) {
|
||||
const string src_op = node->type_string();
|
||||
const string& src_op = node->type_string();
|
||||
if (src_op == "Const" || src_op == "Variable") {
|
||||
*signed_input = true;
|
||||
*range_given = false;
|
||||
|
@ -113,6 +113,36 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Replaces the shape recorded for output 'output_port' of 'node' with 'shape'.
//
// Returns:
//   - Internal if no inference context is registered for 'node';
//   - InvalidArgument if 'output_port' is outside [0, node->num_outputs());
//   - the error from InferenceContext::Merge() if 'shape' cannot be merged
//     with the shape currently stored for that output.
Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              shape_inference::ShapeHandle shape) {
  shape_inference::InferenceContext* const context = GetContext(node);
  if (context == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  const bool port_in_range =
      output_port >= 0 && output_port < node->num_outputs();
  if (!port_in_range) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }

  // The merge serves only as a compatibility check: its result is discarded
  // and the caller-supplied handle is what gets stored below.
  shape_inference::ShapeHandle merged_unused;
  TF_RETURN_IF_ERROR(
      context->Merge(context->output(output_port), shape, &merged_unused));
  context->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs?  At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}
|
||||
|
||||
Status ShapeRefiner::ConstantValue(const Node* node, Tensor* tensor_storage,
|
||||
const Tensor** input_tensor) const {
|
||||
*input_tensor = nullptr;
|
||||
|
@ -46,6 +46,14 @@ class ShapeRefiner {
|
||||
// - The shape inference function returns an error.
|
||||
Status AddNode(const Node* node);
|
||||
|
||||
// Sets 'node's 'output_port' output to have shape 'shape'.
|
||||
//
|
||||
// Returns an error if 'node' was not previously added to this
|
||||
// object, if 'output_port' is invalid, or if 'shape' is
|
||||
// not compatible with the existing shape of the output.
|
||||
Status SetShape(const Node* node, int output_port,
|
||||
shape_inference::ShapeHandle shape);
|
||||
|
||||
// Returns the InferenceContext for 'node', if present.
|
||||
shape_inference::InferenceContext* GetContext(const Node* node) const {
|
||||
auto it = node_to_context_.find(node);
|
||||
|
@ -92,6 +92,33 @@ TEST(ShapeRefinerTest, BadShapes) {
|
||||
ASSERT_EQ("Dimensions must be equal, but are 1 and 2", s.error_message());
|
||||
}
|
||||
|
||||
// Exercises ShapeRefiner::SetShape(): a successful override, out-of-range
// ports, a node that was never added, and an incompatible shape.
TEST(ShapeRefinerTest, SetShape) {
  ShapeRefiner refiner;

  Scope root = Scope::NewRootScope();
  auto registered = ops::Const(root, {{1.0f}, {2.0f}});

  TF_ASSERT_OK(refiner.AddNode(registered.node()));

  auto* inference = refiner.GetContext(registered.node());
  ASSERT_NE(nullptr, inference);

  // Overriding output 0 with [2,?] succeeds and is reported back.
  shape_inference::ShapeHandle requested =
      inference->MakeShape({2, inference->UnknownDim()});
  TF_ASSERT_OK(refiner.SetShape(registered.node(), 0, requested));
  EXPECT_SHAPE("[2,?]", refiner, registered, 0);

  // Out of range.
  ASSERT_FALSE(refiner.SetShape(registered.node(), 1, requested).ok());
  ASSERT_FALSE(refiner.SetShape(registered.node(), -1, requested).ok());

  auto unregistered = ops::Const(root, {{1.0f}, {2.0f}});
  // Forget to add node first.
  ASSERT_FALSE(refiner.SetShape(unregistered.node(), 0, requested).ok());

  // Set an incompatible shape (3 vs 2)
  requested = inference->MakeShape({3, inference->UnknownDim()});
  ASSERT_FALSE(refiner.SetShape(registered.node(), 0, requested).ok());
}
|
||||
|
||||
TEST(ShapeRefinerTest, PropagateConstants) {
|
||||
// Reduction dimension is a variable, so we don't know its value.
|
||||
// So the output shape value is unknown (though its rank is known).
|
||||
|
@ -235,7 +235,15 @@ Status RewriteGraphForExecution(
|
||||
"Must specify at least one target to fetch or execute.");
|
||||
}
|
||||
|
||||
std::unordered_set<string> endpoints(fed_outputs.begin(), fed_outputs.end());
|
||||
std::unordered_set<string> endpoints;
|
||||
for (const string& endpoint_name : fed_outputs) {
|
||||
auto result = endpoints.insert(endpoint_name);
|
||||
if (!result.second) {
|
||||
return errors::InvalidArgument("Endpoint \"", endpoint_name,
|
||||
"\" fed more than once.");
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& fetch : fetch_outputs) {
|
||||
if (endpoints.count(fetch) > 0) {
|
||||
return errors::InvalidArgument(fetch, " is both fed and fetched.");
|
||||
|
@ -491,6 +491,27 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "conv_ops_test",
|
||||
size = "small",
|
||||
deps = [
|
||||
":conv_ops",
|
||||
":image",
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "example_parsing_ops_test",
|
||||
size = "large",
|
||||
@ -1325,6 +1346,7 @@ tf_kernel_library(
|
||||
hdrs = [
|
||||
"conv_grad_ops.h",
|
||||
"deep_conv2d.h",
|
||||
"gemm_functors.h",
|
||||
"winograd_transform.h",
|
||||
],
|
||||
prefix = "conv_ops",
|
||||
@ -1332,6 +1354,7 @@ tf_kernel_library(
|
||||
":bounds_check",
|
||||
":conv_2d",
|
||||
":conv_3d",
|
||||
":image_resizer_state",
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
@ -1958,6 +1981,7 @@ filegroup(
|
||||
"control_flow_ops.h",
|
||||
"conv_2d.h",
|
||||
"conv_ops.h",
|
||||
"depthwise_conv_op.h",
|
||||
"image_resizer_state.h",
|
||||
"maxpooling_op.h",
|
||||
"pad_op.h",
|
||||
@ -1998,6 +2022,7 @@ filegroup(
|
||||
"cwise_op_div.cc",
|
||||
"cwise_op_equal_to.cc",
|
||||
"cwise_op_exp.cc",
|
||||
"cwise_op_floor.cc",
|
||||
"cwise_op_greater.cc",
|
||||
"cwise_op_inverse.cc",
|
||||
"cwise_op_isfinite.cc",
|
||||
@ -2017,6 +2042,7 @@ filegroup(
|
||||
"cwise_op_tanh.cc",
|
||||
"deep_conv2d.cc",
|
||||
"deep_conv2d.h",
|
||||
"depthwise_conv_op.cc",
|
||||
"dynamic_partition_op.cc",
|
||||
"winograd_transform.h",
|
||||
":android_extended_ops_headers",
|
||||
|
@ -67,7 +67,7 @@ class ArgOp : public OpKernel {
|
||||
input.shape().DebugString()));
|
||||
|
||||
TensorShape output_shape;
|
||||
TensorShape input_shape = input.shape();
|
||||
const TensorShape& input_shape = input.shape();
|
||||
for (int d = 0; d < input_dims - 1; ++d) {
|
||||
output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ class ExtractGlimpseOp : public OpKernel {
|
||||
// depth).
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input = context->input(0);
|
||||
const TensorShape input_shape = input.shape();
|
||||
const TensorShape& input_shape = input.shape();
|
||||
const int32 num_dims = input_shape.dims();
|
||||
OP_REQUIRES(
|
||||
context, num_dims == 4,
|
||||
|
@ -190,7 +190,7 @@ class ComputeAccidentalHitsOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& in_true_candidates = context->input(0);
|
||||
TensorShape in_true_candidates_shape = in_true_candidates.shape();
|
||||
const TensorShape& in_true_candidates_shape = in_true_candidates.shape();
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(in_true_candidates_shape) &&
|
||||
in_true_candidates_shape.dim_size(1) == num_true_,
|
||||
errors::InvalidArgument(
|
||||
|
@ -37,15 +37,12 @@ struct scalar_const_op {
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_const_op(const T* v) : val(v) {}
|
||||
|
||||
template <typename Index>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(Index,
|
||||
Index = 0) const {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()() const {
|
||||
return *val;
|
||||
}
|
||||
|
||||
template <typename Index, typename PacketType = Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType
|
||||
packetOp(Index, Index = 0) const {
|
||||
template <typename PacketType = Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packetOp() const {
|
||||
return internal::pset1<PacketType>(*val);
|
||||
}
|
||||
};
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_KERNELS_CONV_OPS_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
@ -38,6 +39,16 @@ class LaunchConv2DOp {
|
||||
TensorFormat data_format);
|
||||
};
|
||||
|
||||
// Used to keep track of persistent memory buffers used within the op.
|
||||
template <class T, size_t size>
|
||||
struct Im2ColBufferResource : public ResourceBase {
|
||||
// This mutex ensures that only a single operation at a time is able to use
|
||||
// the buffer memory held by this resource.
|
||||
mutex mu;
|
||||
T data[size];
|
||||
string DebugString() { return "Im2ColBufferResource"; }
|
||||
};
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
template <typename T>
|
||||
class LaunchConv2DOp<Eigen::GpuDevice, T> {
|
||||
|
486
tensorflow/core/kernels/conv_ops_fused.cc
Normal file
486
tensorflow/core/kernels/conv_ops_fused.cc
Normal file
@ -0,0 +1,486 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Implements convolution operations with other kernels baked into the
|
||||
// processing, to optimize latency and memory usage.
|
||||
|
||||
#include <string.h>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_slice.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/conv_ops.h"
|
||||
#include "tensorflow/core/kernels/gemm_functors.h"
|
||||
#include "tensorflow/core/kernels/image_resizer_state.h"
|
||||
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Combines bilinear resizing and mirror padding into the im2col transformation
|
||||
// stage of convolution.
|
||||
template <class T1, class T2, class T3, class TGemmFunctor>
|
||||
class FusedResizeAndPadConvFunctor {
|
||||
public:
|
||||
void operator()(OpKernelContext* context, const Tensor& input,
|
||||
int input_batches, int resized_height, int resized_width,
|
||||
int padded_height, int padded_width, int input_depth,
|
||||
const T2* filter_data, int filter_height, int filter_width,
|
||||
int filter_count, int stride_rows, int stride_cols,
|
||||
Padding padding, T3* output_data, int output_height,
|
||||
int output_width, const ImageResizerState& st,
|
||||
int top_padding, int bottom_padding, int left_padding,
|
||||
int right_padding, int pad_offset) {
|
||||
if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
|
||||
(input_depth <= 0)) {
|
||||
LOG(WARNING) << "Conv2D was called with bad input dimensions: "
|
||||
<< input_batches << ", " << padded_height << ", "
|
||||
<< padded_width << ", " << input_depth;
|
||||
return;
|
||||
}
|
||||
if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
|
||||
LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
|
||||
<< filter_width << ", " << filter_height << ", "
|
||||
<< filter_count;
|
||||
return;
|
||||
}
|
||||
if ((output_width <= 0) || (output_height <= 0)) {
|
||||
LOG(WARNING) << "Conv2D was called with bad output width or height: "
|
||||
<< output_width << ", " << output_height;
|
||||
return;
|
||||
}
|
||||
|
||||
// These calculations define how the patches will be positioned within the
|
||||
// input image. The actual definitions are quite complex, and rely on the
|
||||
// previously-calculated output size.
|
||||
int filter_left_offset;
|
||||
int filter_top_offset;
|
||||
if (padding == VALID) {
|
||||
filter_left_offset =
|
||||
((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
|
||||
2;
|
||||
filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
|
||||
padded_height + 1) /
|
||||
2;
|
||||
} else {
|
||||
filter_left_offset =
|
||||
((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
|
||||
filter_top_offset =
|
||||
((output_height - 1) * stride_rows + filter_height - padded_height) /
|
||||
2;
|
||||
}
|
||||
|
||||
// The im2col buffer has # of patches rows, and # of filters cols.
|
||||
// It's laid out like this, in row major order in memory:
|
||||
// < filter value count >
|
||||
// ^ +---------------------+
|
||||
// patch | |
|
||||
// count | |
|
||||
// v +---------------------+
|
||||
// Each patch row contains a filter_width x filter_height patch of the
|
||||
// input, with the depth channel as the most contiguous in memory, followed
|
||||
// by the width, then the height. This is the standard memory order in the
|
||||
// image world if it helps to visualize it.
|
||||
const int filter_value_count = filter_width * filter_height * input_depth;
|
||||
|
||||
// We don't want to allocate a buffer to hold all the patches if the size is
|
||||
// going to be extremely large, so break it into chunks if it's bigger than
|
||||
// a limit. Each chunk will be processed serially, so we can refill the
|
||||
// buffer for the next chunk and reuse it, keeping maximum memory size down.
|
||||
// In this case, we've picked 16 megabytes as a reasonable limit.
|
||||
const size_t max_chunk_size = (16 * 1024 * 1024);
|
||||
OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= max_chunk_size,
|
||||
errors::InvalidArgument("Im2Col patch too large for buffer"));
|
||||
const size_t patches_per_chunk =
|
||||
max_chunk_size / (filter_value_count * sizeof(T1));
|
||||
// Because memory allocation is very expensive on mobile platforms, try to
|
||||
// allocate a persistent buffer that will be kept around between calls. We
|
||||
// use TensorFlow's resource management to ensure that the memory will be
|
||||
// released when the session is over.
|
||||
Im2ColBufferResource<T1, max_chunk_size>* im2col_buffer_resource;
|
||||
std::function<Status(Im2ColBufferResource<T1, max_chunk_size>**)> creator =
|
||||
[](Im2ColBufferResource<T1, max_chunk_size>** resource) {
|
||||
*resource = new Im2ColBufferResource<T1, max_chunk_size>();
|
||||
return Status::OK();
|
||||
};
|
||||
OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
|
||||
"Conv2d", "im2col_buffer",
|
||||
&im2col_buffer_resource, creator));
|
||||
// This means that multiple ops can't be run simultaneously on different
|
||||
// threads, because we have a single shared resource. The platforms this is
|
||||
// aimed at have intra-op parallelism as their focus though, so it shouldn't
|
||||
// be an issue.
|
||||
mutex_lock lock_buffer(im2col_buffer_resource->mu);
|
||||
core::ScopedUnref unref_buffer(im2col_buffer_resource);
|
||||
T1* im2col_buffer = im2col_buffer_resource->data;
|
||||
|
||||
typename TTypes<T1, 4>::ConstTensor input_data = input.tensor<T1, 4>();
|
||||
|
||||
for (int batch = 0; batch < input_batches; ++batch) {
|
||||
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||
const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
|
||||
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||
const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
|
||||
const int patch_index = (batch * output_width * output_height) +
|
||||
(out_y * output_width) + out_x;
|
||||
const int patch_index_within_chunk = patch_index % patches_per_chunk;
|
||||
T1* im2col_patch_start =
|
||||
im2col_buffer + (patch_index_within_chunk * filter_value_count);
|
||||
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||
const int conv_in_y = in_y_origin + filter_y;
|
||||
float in_y = (conv_in_y - top_padding);
|
||||
if (in_y < 0) {
|
||||
in_y = -(in_y + 1.0f - pad_offset);
|
||||
} else if (in_y >= resized_height) {
|
||||
in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
|
||||
}
|
||||
in_y *= st.height_scale;
|
||||
const int64 top_y_index = static_cast<int64>(std::floor(in_y));
|
||||
const int64 bottom_y_index = std::min(
|
||||
static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
|
||||
const T1 y_lerp = in_y - top_y_index;
|
||||
T1* im2col_row_start =
|
||||
im2col_patch_start + (filter_y * filter_width * input_depth);
|
||||
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
|
||||
const int conv_in_x = in_x_origin + filter_x;
|
||||
float in_x = (conv_in_x - left_padding);
|
||||
if (in_x < 0) {
|
||||
in_x = -(in_x + 1.0f - pad_offset);
|
||||
} else if (in_x >= resized_width) {
|
||||
in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
|
||||
}
|
||||
in_x *= st.width_scale;
|
||||
const int64 left_x_index = static_cast<int64>(std::floor(in_x));
|
||||
const int64 right_x_index = std::min(
|
||||
static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
|
||||
const T1 x_lerp = in_x - left_x_index;
|
||||
T1* im2col_row_pixel =
|
||||
im2col_row_start + (filter_x * input_depth);
|
||||
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
|
||||
T1 in_value;
|
||||
if ((conv_in_x >= 0) && (conv_in_x < padded_width) &&
|
||||
(conv_in_y >= 0) && (conv_in_y < padded_height)) {
|
||||
const T1 top_left(
|
||||
input_data(batch, top_y_index, left_x_index, in_channel));
|
||||
const T1 top_right(input_data(batch, top_y_index,
|
||||
right_x_index, in_channel));
|
||||
const T1 bottom_left(input_data(batch, bottom_y_index,
|
||||
left_x_index, in_channel));
|
||||
const T1 bottom_right(input_data(batch, bottom_y_index,
|
||||
right_x_index, in_channel));
|
||||
const T1 top = top_left + (top_right - top_left) * x_lerp;
|
||||
const T1 bottom =
|
||||
bottom_left + (bottom_right - bottom_left) * x_lerp;
|
||||
in_value = top + (bottom - top) * y_lerp;
|
||||
} else {
|
||||
in_value = T1(0);
|
||||
}
|
||||
im2col_row_pixel[in_channel] = in_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
const bool is_last_in_chunk =
|
||||
(patch_index_within_chunk == (patches_per_chunk - 1));
|
||||
const bool is_last_overall =
|
||||
((batch == (input_batches - 1)) &&
|
||||
(out_y == (output_height - 1)) && (out_x == (output_width - 1)));
|
||||
if (is_last_in_chunk || is_last_overall) {
|
||||
// Now we've assembled a set of image patches into a matrix, apply a
|
||||
// GEMM matrix multiply of the patches as rows, times the filter
|
||||
// weights in columns, to get partial results in the output matrix.
|
||||
const int how_many_patches = patch_index_within_chunk + 1;
|
||||
const int m = how_many_patches;
|
||||
const int n = filter_count;
|
||||
const int k = filter_value_count;
|
||||
const int lda = filter_value_count;
|
||||
const int ldb = filter_count;
|
||||
const int ldc = filter_count;
|
||||
const size_t start_patch_index =
|
||||
patch_index - (how_many_patches - 1);
|
||||
T3* chunk_output_data =
|
||||
output_data + (start_patch_index * filter_count);
|
||||
TGemmFunctor gemm_functor;
|
||||
gemm_functor(m, n, k, im2col_buffer, lda, filter_data, ldb,
|
||||
chunk_output_data, ldc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Implements a version of convolution with bilinear resizing and mirror padding
|
||||
// included.
|
||||
template <class T, class TConvFunctor>
|
||||
class FusedResizeConv2DUsingGemmOp : public OpKernel {
|
||||
public:
|
||||
explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("resize_align_corners", &align_corners_));
|
||||
MirrorPadMode mode;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));
|
||||
|
||||
switch (mode) {
|
||||
case MirrorPadMode::SYMMETRIC: {
|
||||
offset_ = 0;
|
||||
break;
|
||||
}
|
||||
case MirrorPadMode::REFLECT: {
|
||||
offset_ = 1;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
OP_REQUIRES(context, false,
|
||||
errors::InvalidArgument(
|
||||
"mode must be either REFLECT or SYMMETRIC."));
|
||||
}
|
||||
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
|
||||
OP_REQUIRES(context, strides_.size() == 4,
|
||||
errors::InvalidArgument("Sliding window strides field must "
|
||||
"specify 4 dimensions"));
|
||||
const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
|
||||
const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
|
||||
OP_REQUIRES(
|
||||
context, stride_n == 1 && stride_c == 1,
|
||||
errors::InvalidArgument("Current implementation does not yet support "
|
||||
"strides in the batch and depth dimensions."));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Input tensor is of the following dimensions:
|
||||
// [ batch, in_rows, in_cols, in_depth ]
|
||||
const Tensor& input = context->input(0);
|
||||
OP_REQUIRES(context, (input.shape().num_elements() > 0),
|
||||
errors::InvalidArgument("Input tensor can't be empty"));
|
||||
|
||||
ImageResizerState st(align_corners_);
|
||||
st.ValidateAndCalculateOutputSize(context, input);
|
||||
if (!context->status().ok()) return;
|
||||
const TensorShape resized_shape(
|
||||
{input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
|
||||
|
||||
const Tensor& paddings = context->input(2);
|
||||
|
||||
const int dims = resized_shape.dims();
|
||||
OP_REQUIRES(
|
||||
context, TensorShapeUtils::IsMatrix(paddings.shape()) &&
|
||||
paddings.dim_size(1) == 2,
|
||||
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
|
||||
paddings.shape().DebugString()));
|
||||
const int fixed_dims =
|
||||
(allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
|
||||
? 1
|
||||
: dims;
|
||||
OP_REQUIRES(
|
||||
context, fixed_dims == paddings.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"The first dimension of paddings must be the rank of inputs: ",
|
||||
fixed_dims, " ", paddings.shape().DebugString(), " ",
|
||||
resized_shape.DebugString()));
|
||||
OP_REQUIRES(
|
||||
context, dims == paddings.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"The first dimension of paddings must be the rank of inputs: ",
|
||||
dims, " ", paddings.shape().DebugString(), " ",
|
||||
resized_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES(
|
||||
context, dims == 4,
|
||||
errors::InvalidArgument(
|
||||
"Fused mirror padding only supports four-dimensional inputs, but ",
|
||||
dims, " requested"));
|
||||
|
||||
// Compute the shape of the output tensor, and allocate it.
|
||||
TensorShape padded_shape;
|
||||
TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
|
||||
for (int d = 0; d < dims; ++d) {
|
||||
const int32 before =
|
||||
paddings_matrix(d, 0); // Pad before existing elements.
|
||||
const int32 after =
|
||||
paddings_matrix(d, 1); // Pad after exisitng elements.
|
||||
OP_REQUIRES(context, before >= 0 && after >= 0,
|
||||
errors::InvalidArgument("paddings must be non-negative: ",
|
||||
before, " ", after));
|
||||
if (offset_ == 0) { // SYMMETRIC mode.
|
||||
OP_REQUIRES(
|
||||
context, before <= resized_shape.dim_size(d) &&
|
||||
after <= resized_shape.dim_size(d),
|
||||
errors::InvalidArgument("paddings must be no greater "
|
||||
"than the dimension size: ",
|
||||
before, ", ", after, " greater than ",
|
||||
resized_shape.dim_size(d)));
|
||||
} else if (offset_ == 1) { // REFLECT mode.
|
||||
OP_REQUIRES(
|
||||
context, before < resized_shape.dim_size(d) &&
|
||||
after < resized_shape.dim_size(d),
|
||||
errors::InvalidArgument("paddings must be less than"
|
||||
" the dimension size: ",
|
||||
before, ", ", after, " not less than ",
|
||||
resized_shape.dim_size(d)));
|
||||
}
|
||||
padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
|
||||
}
|
||||
|
||||
OP_REQUIRES(
|
||||
context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
|
||||
errors::InvalidArgument(
|
||||
"Fused mirror padding only support spatial padding, not batches: ",
|
||||
paddings.DebugString()));
|
||||
OP_REQUIRES(
|
||||
context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
|
||||
errors::InvalidArgument(
|
||||
"Fused mirror padding only support spatial padding, not channels: ",
|
||||
paddings.DebugString()));
|
||||
const int32 top_padding = paddings_matrix(1, 0);
|
||||
const int32 bottom_padding = paddings_matrix(1, 1);
|
||||
const int32 left_padding = paddings_matrix(2, 0);
|
||||
const int32 right_padding = paddings_matrix(2, 1);
|
||||
|
||||
// Input filter is of the following dimensions:
|
||||
// [ filter_rows, filter_cols, in_depth, out_depth]
|
||||
const Tensor& filter = context->input(3);
|
||||
|
||||
// For 2D convolution, there should be 4 dimensions.
|
||||
OP_REQUIRES(context, padded_shape.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
padded_shape.DebugString()));
|
||||
OP_REQUIRES(context, filter.dims() == 4,
|
||||
errors::InvalidArgument("filter must be 4-dimensional: ",
|
||||
filter.shape().DebugString()));
|
||||
|
||||
// We only check the first three dims, since the depth is accessed as an
|
||||
// int64 below.
|
||||
for (int i = 0; i < 3; i++) {
|
||||
OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("filter too large"));
|
||||
}
|
||||
|
||||
// The last dimension for input is in_depth. It must be the same as the
|
||||
// filter's in_depth.
|
||||
const int64 in_depth = padded_shape.dim_size(3);
|
||||
OP_REQUIRES(
|
||||
context, in_depth == filter.dim_size(2),
|
||||
errors::InvalidArgument("input and filter must have the same depth: ",
|
||||
in_depth, " vs ", filter.dim_size(2)));
|
||||
|
||||
// The last dimension for filter is out_depth.
|
||||
const int out_depth = static_cast<int>(filter.dim_size(3));
|
||||
|
||||
// The second dimension for input is rows/height.
|
||||
// The first dimension for filter is rows/height.
|
||||
const int64 padded_rows_raw = padded_shape.dim_size(1);
|
||||
OP_REQUIRES(context, FastBoundsCheck(padded_rows_raw,
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("Input rows too large"));
|
||||
const int padded_rows = static_cast<int>(padded_rows_raw);
|
||||
const int filter_rows = static_cast<int>(filter.dim_size(0));
|
||||
const int resized_rows = static_cast<int>(resized_shape.dim_size(1));
|
||||
|
||||
// The third dimension for input is columns/width.
|
||||
// The second dimension for filter is columns/width.
|
||||
const int64 padded_cols_raw = padded_shape.dim_size(2);
|
||||
OP_REQUIRES(context, FastBoundsCheck(padded_cols_raw,
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("Input cols too large"));
|
||||
const int padded_cols = static_cast<int>(padded_cols_raw);
|
||||
const int filter_cols = static_cast<int>(filter.dim_size(1));
|
||||
const int resized_cols = static_cast<int>(resized_shape.dim_size(2));
|
||||
|
||||
// The first dimension for input is batch.
|
||||
const int64 batch_raw = padded_shape.dim_size(0);
|
||||
OP_REQUIRES(context,
|
||||
FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("batch is too large"));
|
||||
const int batch = static_cast<int>(batch_raw);
|
||||
|
||||
// For now we take the stride from the second and third dimensions only (we
|
||||
// do not support striding on the batch or depth dimension).
|
||||
const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
|
||||
const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');
|
||||
|
||||
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
|
||||
OP_REQUIRES_OK(context,
|
||||
GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
|
||||
padding_, &out_rows, &pad_rows));
|
||||
OP_REQUIRES_OK(context,
|
||||
GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
|
||||
padding_, &out_cols, &pad_cols));
|
||||
TensorShape out_shape =
|
||||
ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
|
||||
OP_REQUIRES(context, (out_shape.num_elements() > 0),
|
||||
errors::InvalidArgument("Output tensor can't be empty"));
|
||||
|
||||
// Output tensor is of the following dimensions:
|
||||
// [ in_batch, out_rows, out_cols, out_depth ]
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||
|
||||
VLOG(2) << "Conv2D: in_depth = " << in_depth
|
||||
<< ", padded_cols = " << padded_cols
|
||||
<< ", filter_cols = " << filter_cols
|
||||
<< ", padded_rows = " << padded_rows
|
||||
<< ", filter_rows = " << filter_rows
|
||||
<< ", stride_rows = " << stride_rows
|
||||
<< ", stride_cols = " << stride_cols
|
||||
<< ", out_depth = " << out_depth;
|
||||
|
||||
// If there is nothing to compute, return.
|
||||
if (out_shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
TConvFunctor conv_functor;
|
||||
conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
|
||||
padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
|
||||
filter_cols, out_depth, stride_rows, stride_cols, padding_,
|
||||
output->flat<T>().data(), out_rows, out_cols, st, top_padding,
|
||||
bottom_padding, left_padding, right_padding, offset_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int32> strides_;
|
||||
Padding padding_;
|
||||
bool align_corners_;
|
||||
int offset_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
|
||||
};
|
||||
|
||||
#define REGISTER_FUSED(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("FusedResizeAndPadConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
FusedResizeConv2DUsingGemmOp< \
|
||||
T, \
|
||||
FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);
|
||||
|
||||
TF_CALL_float(REGISTER_FUSED);
|
||||
|
||||
} // namespace tensorflow
|
240
tensorflow/core/kernels/conv_ops_test.cc
Normal file
240
tensorflow/core/kernels/conv_ops_test.cc
Normal file
@ -0,0 +1,240 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class FusedResizePadConvOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void HandwrittenConv() {
|
||||
const int stride = 1;
|
||||
TF_EXPECT_OK(NodeDefBuilder("fused_resize_op", "FusedResizeAndPadConv2D")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("resize_align_corners", false)
|
||||
.Attr("mode", "REFLECT")
|
||||
.Attr("strides", {1, stride, stride, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Finalize(node_def()));
|
||||
TF_EXPECT_OK(InitOp());
|
||||
const int depth = 1;
|
||||
const int image_width = 4;
|
||||
const int image_height = 3;
|
||||
const int image_batch_count = 1;
|
||||
// The image matrix is:
|
||||
// | 1 | 2 | 3 | 4 |
|
||||
// | 5 | 6 | 7 | 8 |
|
||||
// | 9 | 10 | 11 | 12 |
|
||||
Tensor image(DT_FLOAT,
|
||||
{image_batch_count, image_height, image_width, depth});
|
||||
test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
|
||||
// The filter matrix is:
|
||||
// | 1 | 4 | 7 |
|
||||
// | 2 | 5 | 8 |
|
||||
// | 3 | 6 | 9 |
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 1;
|
||||
Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count});
|
||||
test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});
|
||||
|
||||
const int resized_width = image_width;
|
||||
const int resized_height = image_height;
|
||||
|
||||
const int top_padding = 0;
|
||||
const int bottom_padding = 0;
|
||||
const int left_padding = 0;
|
||||
const int right_padding = 0;
|
||||
|
||||
AddInputFromArray<float>(image.shape(), image.flat<float>());
|
||||
AddInputFromArray<int32>(TensorShape({2}), {resized_height, resized_width});
|
||||
AddInputFromArray<int32>(
|
||||
TensorShape({4, 2}),
|
||||
{0, 0, top_padding, bottom_padding, left_padding, right_padding, 0, 0});
|
||||
AddInputFromArray<float>(filter.shape(), filter.flat<float>());
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// We're sliding the 3x3 filter across the 3x4 image, with accesses outside
|
||||
// the input set to zero because we're using the 'SAME' padding mode.
|
||||
// The calculations behind the expected output are:
|
||||
// (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
|
||||
// (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
|
||||
// (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
|
||||
// (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
|
||||
// (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
|
||||
// (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
|
||||
// (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
|
||||
// (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
|
||||
// (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
|
||||
// (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
|
||||
// (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
|
||||
// (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
|
||||
// This means we should end up with this matrix:
|
||||
// | 105 | 150 | 183 | 95 |
|
||||
// | 235 | 312 | 357 | 178 |
|
||||
// | 187 | 234 | 261 | 121 |
|
||||
const int expected_width = image_width;
|
||||
const int expected_height = image_height * filter_count;
|
||||
Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height,
|
||||
expected_width, filter_count}));
|
||||
test::FillValues<float>(
|
||||
&expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121});
|
||||
const Tensor& output = *GetOutput(0);
|
||||
test::ExpectTensorNear<float>(expected, output, 1e-5);
|
||||
}
|
||||
|
||||
void CompareFusedAndSeparate(int input_width, int input_height,
|
||||
int input_depth, int resize_width,
|
||||
int resize_height, int y_padding, int x_padding,
|
||||
int filter_size, int filter_count,
|
||||
bool resize_align_corners, string pad_mode,
|
||||
int stride, string padding) {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
const size_t input_data_size = input_height * input_width * input_depth;
|
||||
Tensor input_data(DT_FLOAT,
|
||||
TensorShape({1, input_height, input_width, input_depth}));
|
||||
for (int i = 0; i < input_data_size; ++i) {
|
||||
input_data.flat<float>()(i) = i + 1.0f;
|
||||
}
|
||||
Output input =
|
||||
Const(root.WithOpName("input"), Input::Initializer(input_data));
|
||||
|
||||
const size_t filter_data_size =
|
||||
filter_size * filter_size * filter_count * input_depth;
|
||||
Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
|
||||
input_depth, filter_count}));
|
||||
for (int i = 0; i < filter_data_size; ++i) {
|
||||
filter_data.flat<float>()(i) = i + 1.0f;
|
||||
}
|
||||
Output filter =
|
||||
Const(root.WithOpName("filter"), Input::Initializer(filter_data));
|
||||
|
||||
Output resize_size =
|
||||
Const(root.WithOpName("resize_size"), {resize_height, resize_width});
|
||||
Output resize =
|
||||
ResizeBilinear(root.WithOpName("resize"), input, resize_size,
|
||||
ResizeBilinear::AlignCorners(resize_align_corners));
|
||||
Output paddings =
|
||||
Const(root.WithOpName("paddings"),
|
||||
{{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
|
||||
Output mirror_pad =
|
||||
MirrorPad(root.WithOpName("mirror_pad"), resize, paddings, pad_mode);
|
||||
Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter,
|
||||
{1, stride, stride, 1}, padding);
|
||||
|
||||
Output fused_conv = FusedResizeAndPadConv2D(
|
||||
root.WithOpName("fused_conv"), input, resize_size, paddings, filter,
|
||||
pad_mode, {1, stride, stride, 1}, padding,
|
||||
FusedResizeAndPadConv2D::ResizeAlignCorners(resize_align_corners));
|
||||
|
||||
tensorflow::GraphDef graph;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph));
|
||||
|
||||
std::unique_ptr<tensorflow::Session> session(
|
||||
tensorflow::NewSession(tensorflow::SessionOptions()));
|
||||
TF_ASSERT_OK(session->Create(graph));
|
||||
|
||||
std::vector<Tensor> unfused_tensors;
|
||||
TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors));
|
||||
|
||||
std::vector<Tensor> fused_tensors;
|
||||
TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
|
||||
|
||||
test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
|
||||
}
|
||||
};
|
||||
|
||||
// Fused op against hand-computed values.
TEST_F(FusedResizePadConvOpTest, HandwrittenConv) { HandwrittenConv(); }

// The remaining cases compare the fused op with the separate
// ResizeBilinear / MirrorPad / Conv2D ops for one configuration each.

// No resize, no padding, 1x1 filter.
TEST_F(FusedResizePadConvOpTest, IdentityComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/10, /*input_height=*/10, /*input_depth=*/1,
      /*resize_width=*/10, /*resize_height=*/10,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

// No resize, no padding, four 4x4 filters over three channels.
TEST_F(FusedResizePadConvOpTest, ConvOnlyComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/10, /*input_height=*/10, /*input_depth=*/3,
      /*resize_width=*/10, /*resize_height=*/10,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/4, /*filter_count=*/4,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

// 10x10 resized to 20x20, no padding, 1x1 filter.
TEST_F(FusedResizePadConvOpTest, ResizeOnlyComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/10, /*input_height=*/10, /*input_depth=*/1,
      /*resize_width=*/20, /*resize_height=*/20,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

// Resize in width only, combined with two 2x2 filters.
TEST_F(FusedResizePadConvOpTest, ResizeAndConvComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/2, /*input_height=*/2, /*input_depth=*/4,
      /*resize_width=*/4, /*resize_height=*/2,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/2, /*filter_count=*/2,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

// As above, with align_corners enabled for the resize.
TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/2, /*input_height=*/2, /*input_depth=*/4,
      /*resize_width=*/4, /*resize_height=*/2,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/2, /*filter_count=*/2,
      /*resize_align_corners=*/true, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

// Resize and convolution with a stride of 2.
TEST_F(FusedResizePadConvOpTest, ResizeAndConvStridedComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/2, /*input_height=*/2, /*input_depth=*/4,
      /*resize_width=*/4, /*resize_height=*/2,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/2, /*filter_count=*/2,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/2, /*padding=*/"SAME");
}

// Aligned resize and convolution with VALID convolution padding.
TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvValidComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/2, /*input_height=*/2, /*input_depth=*/4,
      /*resize_width=*/4, /*resize_height=*/2,
      /*y_padding=*/0, /*x_padding=*/0,
      /*filter_size=*/2, /*filter_count=*/2,
      /*resize_align_corners=*/true, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"VALID");
}

// Mirror padding of 2 on each side, no resize, single channel.
TEST_F(FusedResizePadConvOpTest, PadOnlyComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/1,
      /*resize_width=*/4, /*resize_height=*/4,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

// Mirror padding of 2 on each side, no resize, three channels.
TEST_F(FusedResizePadConvOpTest, PadOnlyWithChannelsComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/3,
      /*resize_width=*/4, /*resize_height=*/4,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

// 4x4 resized to 6x6, then mirror padded by 2 on each side.
TEST_F(FusedResizePadConvOpTest, ResizeAndPadComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/1,
      /*resize_width=*/6, /*resize_height=*/6,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"REFLECT",
      /*stride=*/1, /*padding=*/"SAME");
}

// Padding only, using SYMMETRIC instead of REFLECT.
TEST_F(FusedResizePadConvOpTest, PadOnlySymmetricComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/1,
      /*resize_width=*/4, /*resize_height=*/4,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"SYMMETRIC",
      /*stride=*/1, /*padding=*/"SAME");
}

// Resize plus SYMMETRIC padding over three channels.
TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) {
  CompareFusedAndSeparate(
      /*input_width=*/4, /*input_height=*/4, /*input_depth=*/3,
      /*resize_width=*/6, /*resize_height=*/6,
      /*y_padding=*/2, /*x_padding=*/2,
      /*filter_size=*/1, /*filter_count=*/1,
      /*resize_align_corners=*/false, /*pad_mode=*/"SYMMETRIC",
      /*stride=*/1, /*padding=*/"SAME");
}
|
||||
|
||||
} // namespace tensorflow
|
@ -56,14 +56,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_slice.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/conv_ops.h"
|
||||
#include "tensorflow/core/kernels/gemm_functors.h"
|
||||
#include "tensorflow/core/kernels/image_resizer_state.h"
|
||||
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#if defined(__APPLE__)
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#define USE_ACCELERATE_GEMM
|
||||
#endif // __APPLE__
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
@ -189,87 +188,6 @@ class ReferenceConvFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
// Straightforward triple-loop matrix multiply, kept for debugging and for
// understanding the algorithm rather than for speed. To use it, substitute it
// for FastGemmFunctor in the Im2ColConvFunctor template arguments in the op
// registration. All matrices are row-major.
template <class T1, class T2, class T3>
class ReferenceGemmFunctor {
 public:
  // Computes c = a * b, where a is m x k, b is k x n and c is m x n. lda, ldb
  // and ldc are the distances, in elements, between the starts of consecutive
  // rows of a, b and c. Every element of c is overwritten.
  void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
                  const T2* b, size_t ldb, T3* c, size_t ldc) {
    // Output columns are walked in the outer loop and rows in the inner one.
    for (size_t col = 0; col < n; ++col) {
      for (size_t row = 0; row < m; ++row) {
        T3 accumulator(0);
        for (size_t depth = 0; depth < k; ++depth) {
          const T1 lhs = a[(row * lda) + depth];
          const T2 rhs = b[(depth * ldb) + col];
          accumulator += (lhs * rhs);
        }
        c[(row * ldc) + col] = accumulator;
      }
    }
  }
};
|
||||
|
||||
// Uses the optimized Eigen library to implement the matrix multiplication
|
||||
// required by the Im2ColConvFunctor class. We supply the two input and one
|
||||
// output types so that the accumulator can potentially be higher-precision than
|
||||
// the inputs, even though we don't currently take advantage of this.
|
||||
template <class T1, class T2, class T3>
|
||||
class FastGemmFunctor {
|
||||
public:
|
||||
// Convenience wrappers for the Eigen matrix types we'll be using.
|
||||
typedef Eigen::Map<
|
||||
const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ConstMatrixT1;
|
||||
typedef Eigen::Map<
|
||||
const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ConstMatrixT2;
|
||||
typedef Eigen::Map<
|
||||
Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
MatrixT3;
|
||||
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
||||
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
||||
ConstMatrixT1 a_matrix(a, m, k);
|
||||
ConstMatrixT2 b_matrix(b, k, n);
|
||||
MatrixT3 c_matrix(c, m, n);
|
||||
c_matrix.noalias() = a_matrix * b_matrix;
|
||||
}
|
||||
};
|
||||
|
||||
// If we have Apple's Accelerate framework, use their implementation of GEMM to
|
||||
// get a performance boost for float.
|
||||
#if defined(USE_ACCELERATE_GEMM)
|
||||
template <>
|
||||
class FastGemmFunctor<float, float, float> {
|
||||
public:
|
||||
void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
|
||||
const float* b, size_t ldb, float* c, size_t ldc) {
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
|
||||
lda, b, ldb, 0.0f, c, ldc);
|
||||
}
|
||||
};
|
||||
#endif // USE_ACCELERATE_GEMM
|
||||
|
||||
// Used to keep track of persistent memory buffers used within the op.
|
||||
template <class T, size_t size>
|
||||
struct Im2ColBufferResource : public ResourceBase {
|
||||
mutex mu;
|
||||
T data[size];
|
||||
string DebugString() { return "Im2ColBufferResource"; }
|
||||
};
|
||||
|
||||
// Implements convolution as a two stage process, first packing the patches of
|
||||
// the input image into columns (im2col) and then running GEMM to produce the
|
||||
// final result.
|
||||
@ -344,7 +262,6 @@ class Im2ColConvFunctor {
|
||||
errors::InvalidArgument("Im2Col patch too large for buffer"));
|
||||
const size_t patches_per_chunk =
|
||||
max_chunk_size / (filter_value_count * sizeof(T1));
|
||||
|
||||
// Because memory allocation is very expensive on mobile platforms, try to
|
||||
// allocate a persistent buffer that will be kept around between calls. We
|
||||
// use TensorFlow's resource management to ensure that the memory will be
|
||||
|
@ -99,13 +99,15 @@ struct scalar_sqrt_gradient_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
|
||||
operator()(const T& output, const T& output_gradient) const {
|
||||
return static_cast<T>(0.5) * output_gradient / output;
|
||||
const T out_conj = numext::conj(output);
|
||||
return static_cast<T>(0.5) * output_gradient / out_conj;
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
|
||||
packetOp(const Packet& output, const Packet& output_gradient) const {
|
||||
const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
|
||||
return pdiv(pmul(const_half, output_gradient), output);
|
||||
const Packet out_conj = pconj(output);
|
||||
return pdiv(pmul(const_half, output_gradient), out_conj);
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
@ -123,15 +125,17 @@ struct scalar_rsqrt_gradient_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
|
||||
operator()(const T& output, const T& output_gradient) const {
|
||||
return static_cast<T>(-0.5) * (output_gradient * output) *
|
||||
(output * output);
|
||||
const T out_conj = numext::conj(output);
|
||||
return static_cast<T>(-0.5) * (output_gradient * out_conj) *
|
||||
(out_conj * out_conj);
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
|
||||
packetOp(const Packet& output, const Packet& output_gradient) const {
|
||||
const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
|
||||
return pmul(const_half,
|
||||
pmul(pmul(output_gradient, output), pmul(output, output)));
|
||||
const Packet out_conj = pconj(output);
|
||||
return pmul(const_half, pmul(pmul(output_gradient, out_conj),
|
||||
pmul(out_conj, out_conj)));
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
|
@ -35,21 +35,49 @@ class DrawBoundingBoxesOp : public OpKernel {
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& images = context->input(0);
|
||||
const Tensor& boxes = context->input(1);
|
||||
const int64 depth = images.dim_size(3);
|
||||
|
||||
OP_REQUIRES(context, images.dims() == 4,
|
||||
errors::InvalidArgument("The rank of the images should be 4"));
|
||||
OP_REQUIRES(
|
||||
context, boxes.dims() == 3,
|
||||
errors::InvalidArgument("The rank of the boxes tensor should be 3"));
|
||||
|
||||
OP_REQUIRES(context, images.dim_size(0) == boxes.dim_size(0),
|
||||
errors::InvalidArgument("The batch sizes should be the same"));
|
||||
|
||||
OP_REQUIRES(
|
||||
context, depth == 4 || depth == 1 || depth == 3,
|
||||
errors::InvalidArgument("Channel depth should be either 1 (GRY), "
|
||||
"3 (RGB), or 4 (RGBA)"));
|
||||
|
||||
const int64 batch_size = images.dim_size(0);
|
||||
const int64 height = images.dim_size(1);
|
||||
const int64 width = images.dim_size(2);
|
||||
const int64 depth = images.dim_size(3);
|
||||
const int64 color_table_length = 10;
|
||||
|
||||
// 0: yellow
|
||||
// 1: blue
|
||||
// 2: red
|
||||
// 3: lime
|
||||
// 4: purple
|
||||
// 5: olive
|
||||
// 6: maroon
|
||||
// 7: navy blue
|
||||
// 8: aqua
|
||||
// 9: fuchsia
|
||||
float color_table[color_table_length][4] = {
|
||||
{1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1},
|
||||
{0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1},
|
||||
{0, 1, 1, 1}, {1, 0, 1, 1},
|
||||
};
|
||||
|
||||
// Reset first color channel to 1 if image is GRY.
|
||||
// For GRY images, this means all bounding boxes will be white.
|
||||
if (depth == 1) {
|
||||
for (int64 i = 0; i < color_table_length; i++) {
|
||||
color_table[i][0] = 1;
|
||||
}
|
||||
}
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
@ -62,8 +90,8 @@ class DrawBoundingBoxesOp : public OpKernel {
|
||||
for (int64 b = 0; b < batch_size; ++b) {
|
||||
const int64 num_boxes = boxes.dim_size(1);
|
||||
const auto tboxes = boxes.tensor<T, 3>();
|
||||
|
||||
for (int64 bb = 0; bb < num_boxes; ++bb) {
|
||||
int64 color_index = bb % color_table_length;
|
||||
const int64 min_box_row =
|
||||
static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
|
||||
const int64 min_box_row_clamp =
|
||||
@ -122,22 +150,34 @@ class DrawBoundingBoxesOp : public OpKernel {
|
||||
// Draw top line.
|
||||
if (min_box_row >= 0) {
|
||||
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
|
||||
canvas(b, min_box_row, j, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
||||
for (int64 c = 0; c < depth; c++) {
|
||||
canvas(b, min_box_row, j, c) =
|
||||
static_cast<T>(color_table[color_index][c]);
|
||||
}
|
||||
}
|
||||
// Draw bottom line.
|
||||
if (max_box_row < height) {
|
||||
for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
|
||||
canvas(b, max_box_row, j, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
||||
for (int64 c = 0; c < depth; c++) {
|
||||
canvas(b, max_box_row, j, c) =
|
||||
static_cast<T>(color_table[color_index][c]);
|
||||
}
|
||||
}
|
||||
// Draw left line.
|
||||
if (min_box_col >= 0) {
|
||||
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
|
||||
canvas(b, i, min_box_col, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
||||
for (int64 c = 0; c < depth; c++) {
|
||||
canvas(b, i, min_box_col, c) =
|
||||
static_cast<T>(color_table[color_index][c]);
|
||||
}
|
||||
}
|
||||
// Draw right line.
|
||||
if (max_box_col < width) {
|
||||
for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
|
||||
canvas(b, i, max_box_col, 0) = Eigen::NumTraits<T>::quiet_NaN();
|
||||
for (int64 c = 0; c < depth; c++) {
|
||||
canvas(b, i, max_box_col, c) =
|
||||
static_cast<T>(color_table[color_index][c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ class GatherNdOp : public OpKernel {
|
||||
"index innermost dimension length must be <= params rank; saw: ",
|
||||
indices.dim_size(indices.dims() - 1), " vs. ", params.dims()));
|
||||
|
||||
TensorShape indices_shape(indices.shape());
|
||||
const TensorShape& indices_shape(indices.shape());
|
||||
const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);
|
||||
|
||||
// Check that we have enough index space
|
||||
@ -79,7 +79,7 @@ class GatherNdOp : public OpKernel {
|
||||
N_result *= indices_shape.dim_size(i);
|
||||
}
|
||||
|
||||
TensorShape params_shape(params.shape());
|
||||
const TensorShape& params_shape(params.shape());
|
||||
Index total_nd = params_shape.dims();
|
||||
|
||||
TensorShape result_shape(indices_shape);
|
||||
|
105
tensorflow/core/kernels/gemm_functors.h
Normal file
105
tensorflow/core/kernels/gemm_functors.h
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is a set of different implementations for the basic matrix by matrix
|
||||
// multiply function, commonly known as GEMM after the BLAS library's naming.
|
||||
// Having a standard interface enables us to swap out implementations on
|
||||
// different platforms, to make sure we're using the optimal version. They are
|
||||
// implemented as C++ template functors, so they're easy to swap into all of the
|
||||
// different kernels that use them.
|
||||
|
||||
#include <string.h>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#define USE_ACCELERATE_GEMM
|
||||
#endif  // defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
|
||||
|
||||
// Plain triple-loop matrix multiplication. It is slow but easy to follow,
// which makes it handy for debugging and for understanding the algorithm. To
// enable it, use it in place of FastGemmFunctor in the Im2ColConvFunctor
// template definition inside the op registration. All matrices are assumed to
// be stored row-major in memory.
template <class T1, class T2, class T3>
class ReferenceGemmFunctor {
 public:
  // Writes the product of a (m rows by k columns) and b (k rows by n columns)
  // into c (m rows by n columns). lda, ldb and ldc are the distances, counted
  // in elements, between the starts of consecutive rows of a, b and c.
  void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
                  const T2* b, size_t ldb, T3* c, size_t ldc) {
    for (size_t col = 0; col < n; col++) {
      for (size_t row = 0; row < m; row++) {
        // Dot product of row `row` of a with column `col` of b, accumulated
        // in the output type.
        T3 sum(0);
        for (size_t depth = 0; depth < k; depth++) {
          const T1 lhs = a[(row * lda) + depth];
          const T2 rhs = b[col + (depth * ldb)];
          sum += (lhs * rhs);
        }
        c[(row * ldc) + col] = sum;
      }
    }
  }
};
|
||||
|
||||
// Uses the optimized Eigen library to implement the matrix multiplication
|
||||
// required by the Im2ColConvFunctor class. We supply the two input and one
|
||||
// output types so that the accumulator can potentially be higher-precision than
|
||||
// the inputs, even though we don't currently take advantage of this.
|
||||
template <class T1, class T2, class T3>
|
||||
class FastGemmFunctor {
|
||||
public:
|
||||
// Convenience wrappers for the Eigen matrix types we'll be using.
|
||||
typedef Eigen::Map<
|
||||
const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ConstMatrixT1;
|
||||
typedef Eigen::Map<
|
||||
const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ConstMatrixT2;
|
||||
typedef Eigen::Map<
|
||||
Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
MatrixT3;
|
||||
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
||||
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
||||
ConstMatrixT1 a_matrix(a, m, k);
|
||||
ConstMatrixT2 b_matrix(b, k, n);
|
||||
MatrixT3 c_matrix(c, m, n);
|
||||
c_matrix.noalias() = a_matrix * b_matrix;
|
||||
}
|
||||
};
|
||||
|
||||
// If we have Apple's Accelerate framework, use their implementation of GEMM to
|
||||
// get a performance boost for float.
|
||||
#if defined(USE_ACCELERATE_GEMM)
|
||||
template <>
|
||||
class FastGemmFunctor<float, float, float> {
|
||||
public:
|
||||
void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
|
||||
const float* b, size_t ldb, float* c, size_t ldc) {
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
|
||||
lda, b, ldb, 0.0f, c, ldc);
|
||||
}
|
||||
};
|
||||
#endif // USE_ACCELERATE_GEMM
|
@ -49,12 +49,13 @@ struct ImageResizerState {
|
||||
explicit ImageResizerState(bool align_corners)
|
||||
: align_corners_(align_corners) {}
|
||||
|
||||
// ValidateAndCreateOutput checks the bounds on the input tensors
|
||||
// ValidateAndCalculateOutputSize checks the bounds on the input tensors
|
||||
// and requested size, sets up some of the resizing state such as the
|
||||
// height_scale and width_scale, and allocates the output.
|
||||
// height_scale and width_scale, and calculates the output size.
|
||||
// If any of these operations fails, it sets an error status in
|
||||
// the context, which the caller must check.
|
||||
void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
|
||||
void ValidateAndCalculateOutputSize(OpKernelContext* context,
|
||||
const Tensor& input) {
|
||||
OP_REQUIRES(context, input.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
input.shape().DebugString()));
|
||||
@ -87,12 +88,18 @@ struct ImageResizerState {
|
||||
OP_REQUIRES(
|
||||
context, input.dim_size(1) > 0 && input.dim_size(2) > 0,
|
||||
errors::InvalidArgument("input image must be of non-zero size"));
|
||||
height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
|
||||
width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
|
||||
}
|
||||
|
||||
// Calculates all the required variables, and allocates the output.
|
||||
void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
|
||||
ValidateAndCalculateOutputSize(context, input);
|
||||
if (!context->status().ok()) return;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
0, TensorShape({input.dim_size(0), out_height,
|
||||
out_width, input.dim_size(3)}),
|
||||
&output));
|
||||
height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
|
||||
width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
|
||||
}
|
||||
|
||||
int64 batch_size;
|
||||
|
@ -272,7 +272,7 @@ class MaxPoolingGradOp : public OpKernel {
|
||||
OP_REQUIRES(context, out_backprop.dims() == 4,
|
||||
errors::InvalidArgument("out_backprop must be 4-dimensional"));
|
||||
|
||||
TensorShape output_shape = tensor_in.shape();
|
||||
const TensorShape& output_shape = tensor_in.shape();
|
||||
|
||||
Tensor tensor_out_dup;
|
||||
OP_REQUIRES_OK(context,
|
||||
|
@ -185,6 +185,7 @@ class OpsTestBase : public ::testing::Test {
|
||||
test::SetOutputAttrs(params_.get(), &attrs);
|
||||
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
|
||||
params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
|
||||
params_.get()->resource_manager = device_.get()->resource_manager();
|
||||
|
||||
context_.reset(new OpKernelContext(params_.get()));
|
||||
device_->Compute(kernel_.get(), context_.get());
|
||||
|
@ -34,6 +34,16 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/guarded_philox_random.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
// Macro pair for bracketing a deliberate exact floating-point comparison.
// When the condition below holds, DISABLE_FLOAT_EQUALITY_WARNING pushes the
// current GCC diagnostic state and ignores -Wfloat-equal, and
// ENABLE_FLOAT_EQUALITY_WARNING pops that state again. Otherwise both macros
// expand to nothing.
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
|
||||
#define DISABLE_FLOAT_EQUALITY_WARNING \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
|
||||
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
|
||||
#else
|
||||
#define DISABLE_FLOAT_EQUALITY_WARNING
|
||||
#define ENABLE_FLOAT_EQUALITY_WARNING
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
@ -355,47 +365,23 @@ class RandomGammaOp : public OpKernel {
|
||||
// Several calculations can be done on a per-alpha basis.
|
||||
const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
|
||||
|
||||
if (alpha < 0.3) {
|
||||
// For very small alpha, we use the log-space algorithm proposed in
|
||||
// "Simulating from a gamma distribution with small shape parameter",
|
||||
// http://arxiv.org/abs/1302.1884
|
||||
const double lambda = 1 / alpha - 1;
|
||||
const double w = alpha / (M_E /* exp(1) */ * (1 - alpha));
|
||||
const double r = 1 / (1 + w);
|
||||
|
||||
// Compute the rest of the samples for the current alpha value.
|
||||
DISABLE_FLOAT_EQUALITY_WARNING
|
||||
if (alpha == double(1.0)) {
|
||||
ENABLE_FLOAT_EQUALITY_WARNING
|
||||
// Sample from an exponential distribution.
|
||||
for (int64 sample_idx = output_idx % num_samples;
|
||||
sample_idx < num_samples && output_idx < limit_output;
|
||||
sample_idx++, output_idx++) {
|
||||
// Since each sample may use a variable number of normal/uniform
|
||||
// samples, and we want data stable regardless of sharding
|
||||
// As we want data stable regardless of sharding
|
||||
// (including eventually on GPU), we skip on a per-sample basis.
|
||||
PhiloxRandom gen = rng;
|
||||
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
||||
short uniform_remaining = 0;
|
||||
|
||||
// Keep trying until we don't reject a sample. In practice, we
|
||||
// expect a low rejection rate.
|
||||
while (true) {
|
||||
UNIFORM(u);
|
||||
double z;
|
||||
if (u <= r) {
|
||||
z = -log(u / r);
|
||||
} else {
|
||||
UNIFORM(v);
|
||||
z = log(v) / lambda;
|
||||
}
|
||||
double eta = z >= 0 ? exp(-z) : w * lambda * exp(lambda * z);
|
||||
UNIFORM(v);
|
||||
double h = exp(-z - exp(-z / alpha));
|
||||
if (h > eta * v) {
|
||||
samples_alpha_offset[sample_idx * num_alphas] =
|
||||
static_cast<T>(exp(-z / alpha));
|
||||
break;
|
||||
}
|
||||
} // while: true
|
||||
} // for: sample_idx
|
||||
} else { // so, alpha >= 0.3
|
||||
UNIFORM(u);
|
||||
const double res = -log(1.0 - u);
|
||||
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
|
||||
} // for (sample_idx)
|
||||
} else { // if alpha != 1.0
|
||||
// Transformation-rejection from pairs of uniform and normal random
|
||||
// variables. http://dl.acm.org/citation.cfm?id=358414
|
||||
//
|
||||
@ -454,7 +440,7 @@ class RandomGammaOp : public OpKernel {
|
||||
}
|
||||
} // while: true
|
||||
} // for: sample_idx
|
||||
} // if: alpha < 0.3
|
||||
} // if (alpha == 1.0)
|
||||
} // for: output_idx
|
||||
}; // DoWork
|
||||
#undef UNIFORM
|
||||
@ -463,9 +449,7 @@ class RandomGammaOp : public OpKernel {
|
||||
// Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
|
||||
// each = ~60.
|
||||
// All of this /0.95 due to the rejection possibility = ~85.
|
||||
// All of this * ~2 to incorporate possibility of the log/exp branch for
|
||||
// low-alpha. (1 log, 4 exp, 3/, 3*)
|
||||
static const int kElementCost = 170 + 2 * Normal::kElementCost +
|
||||
static const int kElementCost = 85 + 2 * Normal::kElementCost +
|
||||
Uniform::kElementCost +
|
||||
3 * PhiloxRandom::kElementCost;
|
||||
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||
|
@ -62,10 +62,10 @@ Tensor MakeInput(const TensorShape& shape,
|
||||
TEST_F(RestoreOpTest, RestoreSimple) {
|
||||
const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
|
||||
const std::vector<string> tensor_names = {
|
||||
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
|
||||
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
|
||||
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
|
||||
"tensor_half"};
|
||||
"tensor_bool", "tensor_int", "tensor_float", "tensor_double",
|
||||
"tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
|
||||
"tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
|
||||
"tensor_half", "tensor_float_empty"};
|
||||
|
||||
// We first need to write a tensor using the save_op
|
||||
{
|
||||
@ -164,6 +164,11 @@ TEST_F(RestoreOpTest, RestoreSimple) {
|
||||
return static_cast<Eigen::half>(x) / Eigen::half(5);
|
||||
});
|
||||
inputs.push_back({nullptr, &input_14});
|
||||
// Input #15 is a 2-d empty float tensor
|
||||
Tensor input_15 = MakeInput<float>(TensorShape({2, 0}), [](int x) -> float {
|
||||
return static_cast<float>(x) / 10;
|
||||
});
|
||||
inputs.push_back({nullptr, &input_15});
|
||||
OpKernelContext::Params params;
|
||||
params.device = device.get();
|
||||
params.frame_iter = FrameAndIter(0, 0);
|
||||
@ -341,6 +346,15 @@ TEST_F(RestoreOpTest, RestoreSimple) {
|
||||
output->flat<Eigen::half>()(i));
|
||||
}
|
||||
}
|
||||
// The 2-d empty float tensor
|
||||
{
|
||||
MakeRestoreOp(DT_FLOAT);
|
||||
(*mutable_input(1).tensor).scalar<string>()() = tensor_names[13];
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
Tensor* output = GetOutput(0);
|
||||
TensorShape expected({2, 0});
|
||||
EXPECT_TRUE(output->shape().IsSameSize(expected));
|
||||
}
|
||||
}
|
||||
|
||||
class RestoreSliceOpTest : public OpsTestBase {
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user