Change: Clean up documentation for ReverseSequence.
Change: Updated several TensorFlow operations to use 32-bit indices on GPU.
Change: Add attribute batch_dim to ReverseSequenceOp.
Change: Fix error in convert_to_records.py, as reported in https://github.com/tensorflow/tensorflow/issues/370 by AlexUnderMicrocontRoll.
Change: Update TensorBoard README.
Change: Fixes to boolean flags reported in https://github.com/tensorflow/tensorflow/issues/379. Supports:
  --bool_flag=True --> True
  --bool_flag=False --> False
  --bool_flag=gibberish --> False
  --bool_flag --> True
  --nobool_flag --> False
Fixes #379
Change: Update generated Op docs.
Change: Enable local development of TensorBoard using gulp. Also make tf-tensorboard a regular component rather than a special case. This is mostly effected by creating tfserve.js, a small server with clever routing to load from bower_components/ and components/ using the paths that work within google3. Workflow: `gulp serve`
Change: Add a full working code example to the TensorBoard and summaries tutorial.
Change: Fix seq2seq_test when running on GPU. The "proj_w" and "proj_b" variables were being created before the `test_session()`'s device function took effect, which pushed the placement algorithm into making an incorrect decision.
Change: Add a sentence to the TensorBoard README on how to serialize summary data to logs, and link to the how-to tutorial on the TensorFlow website.
Change: Add error-catching code if string_input_producer is supplied a null input. Before this change, it would die with an opaque shape error from inside the queue. This change catches (most) Python null lists being passed directly in, and detects null tensors at runtime. Adds two tests for this to input_test.py.
Change: Speed up models that use the same variable multiple times in the case where variables must be copied across devices:
  - Have Variables wrap the Variable op in an Identity op when converted to Tensor. This avoids multiple copies across devices if a variable is used multiple times in a computation.
  - Add Variable.mutable() to return the non-wrapped Variable op, for use when assigning new values.
  - Add an as_ref parameter to convert_to_tensor() so callers can specify whether they plan to assign a new value to the result of the conversion. Make Variable return the result of Variable.mutable() when as_ref is True.
  - Make all ops that assign values to variables pass as_ref=True when converting their arguments.
Change: Reduce critical section times in gpu_event_mgr.h: (1) call stream->ThenRecordEvent outside the EventMgr critical section; (2) do memory deallocation outside the critical section. Speeds up one configuration of ptb_word_lm from 2924 words per second (wps) to 3278 wps on my desktop machine with a Titan X.
Change: Remove some colons that break the open source build. ::tensorflow::StringPiece breaks for @raingo (see https://github.com/tensorflow/tensorflow/issues/358); tensorflow::StringPiece (without the leading colons) seems to fix the problem.
Change: Added a check that the inputs to an Operation are a list, and make a defensive copy of the inputs. This covers cases where the input list is changed, such as in _add_input.
Change: Use standard names for TensorFlow dtypes in the tutorial.
Change: Add tests for tensor inputs.
Change: Fix build after declaring more types for ops.
Change: Switch to 32-bit indexing to speed up convolutions and concatenations.
Change: Add convert_image op to convert between types for images (similar to OpenCV's cvtScale).
Change: Make cast work between numeric types (bool, uint8, int16, int32, int64, float, double).
Change: Pad input data for an odd number of paddings, so we can use cudnn anyway. Also fix the total padding computation when padding==VALID. This CL makes the Googlenet benchmark run 5x faster.
Change: Support IndexedSlices in ConcatGrad.
Change: Sampled softmax op uses one embedding lookup for positive and negative samples; add float64 support for sampled softmax.
Change: Move RNN code out of models.rnn (without breaking existing code). The API may still undergo minor changes until full documentation is added.
Change: Changed to use per-step stacks for the accumulators used in while-loop gradient computation. This addresses the problem caused by using concat without sufficient static shape information. It should also improve performance, as we avoid those expensive concats.
Change: Update generated Op docs.
Change: Improve error messages when the optimizer finds no variables to minimize or when none of the variables has gradients.
Change: Say that -1 isn't just for flattening in the reshape docs; also add scalar reshape (reshape(t, [])) as an example. This fixes https://github.com/tensorflow/tensorflow/issues/281.
Change: This is a test.

Base CL: 109118714
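A minimal sketch of the boolean-flag semantics listed above; parse_bool_flag is a hypothetical helper written for this note, not the actual flags implementation:

    def parse_bool_flag(argv, name):
      # Returns True/False per the mapping above, or None if the flag is absent.
      for arg in argv:
        if arg == "--" + name:        # bare --bool_flag
          return True
        if arg == "--no" + name:      # --nobool_flag
          return False
        if arg.startswith("--" + name + "="):
          # Only the exact string "True" parses as True; gibberish maps to False.
          return arg.split("=", 1)[1] == "True"
      return None

    assert parse_bool_flag(["--bool_flag=True"], "bool_flag") is True
    assert parse_bool_flag(["--bool_flag=False"], "bool_flag") is False
    assert parse_bool_flag(["--bool_flag=gibberish"], "bool_flag") is False
    assert parse_bool_flag(["--bool_flag"], "bool_flag") is True
    assert parse_bool_flag(["--nobool_flag"], "bool_flag") is False

And a short illustration of the reshape behavior described in the reshape-docs change (scalar reshape and -1 for dimension inference); the tensor names here are made up for the example:

    t = tf.constant([7])
    scalar = tf.reshape(t, [])     # reshape a one-element tensor to a scalar: shape []
    m = tf.constant([[1, 2], [3, 4]])
    flat = tf.reshape(m, [-1])     # -1 infers the size: shape [4]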
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functional tests for reduction ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.python.platform

import numpy as np
import tensorflow as tf

from tensorflow.python.framework import tensor_shape


class SumReductionTest(tf.test.TestCase):

  def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
    np_ans = x
    if reduction_axes is None:
      np_ans = np.sum(np_ans, keepdims=keep_dims)
    else:
      reduction_axes = np.array(reduction_axes).astype(np.int32)
      for ra in reduction_axes.ravel()[::-1]:
        np_ans = np.sum(np_ans, axis=ra, keepdims=keep_dims)
    with self.test_session(use_gpu=use_gpu):
      tf_ans = tf.reduce_sum(x, reduction_axes, keep_dims)
      out = tf_ans.eval()
    self.assertAllClose(np_ans, out)
    self.assertShapeEqual(np_ans, tf_ans)

  def _compareAll(self, x, reduction_axes):
    if reduction_axes is not None and np.shape(reduction_axes) == (1,):
      # Test scalar reduction_axes argument
      self._compareAll(x, reduction_axes[0])
    self._compare(x, reduction_axes, False, use_gpu=True)
    self._compare(x, reduction_axes, False, use_gpu=False)
    self._compare(x, reduction_axes, True, use_gpu=True)
    self._compare(x, reduction_axes, True, use_gpu=False)

  def testFloatReduce1D(self):
    # Create a 1D array of floats
    np_arr = np.arange(1, 6).reshape([5]).astype(np.float32)
    self._compareAll(np_arr, [0])

  def testFloatReduce2D(self):
    # Create a 2D array of floats and reduce across all possible
    # dimensions
    np_arr = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    self._compareAll(np_arr, None)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [0, 1])

  def testFloatReduce3D(self):
    # Create a 3D array of floats and reduce across all possible
    # dimensions
    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
    self._compareAll(np_arr, None)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])

  def testFloatReduce4D(self):
    # Create a 4D array of floats and reduce across some
    # dimensions
    np_arr = np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float32)
    self._compareAll(np_arr, None)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    # Need specialization for reduce(4D, [0, 2])
    # self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])
    self._compareAll(np_arr, [1, 2, 3])
    self._compareAll(np_arr, [0, 1, 2, 3])

  def testFloatReduce5D(self):
    # Create a 5D array of floats and reduce across some dimensions
    np_arr = np.arange(0, 840).reshape([2, 3, 5, 7, 4]).astype(np.float32)
    self._compareAll(np_arr, None)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    # Need specialization for reduce(5D, [0, 2])
    # self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])
    self._compareAll(np_arr, [1, 2, 3])
    self._compareAll(np_arr, [0, 1, 2, 3])
    self._compareAll(np_arr, [1, 2, 3, 4])
    self._compareAll(np_arr, [0, 1, 2, 3, 4])
  # Simple tests for various types.
  def testDoubleReduce1D(self):
    np_arr = np.arange(1, 6).reshape([5]).astype(np.float64)
    self._compare(np_arr, [], False)
    self._compare(np_arr, [0], False)

  def testInt32Reduce1D(self):
    np_arr = np.arange(1, 6).reshape([5]).astype(np.int32)
    self._compare(np_arr, [], False)
    self._compare(np_arr, [0], False)

  def testInvalidIndex(self):
    np_arr = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    input_tensor = tf.convert_to_tensor(np_arr)
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda e: "Invalid reduction dimension" in str(e)):
      tf.reduce_sum(input_tensor, [-1])
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda e: "Invalid reduction dimension" in str(e)):
      tf.reduce_sum(input_tensor, [2])
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda e: "Invalid reduction dimension" in str(e)):
      tf.reduce_sum(input_tensor, [0, 2])

  # Int64??
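  # A commented-out sketch of the int64 variant queried above, following the
  # file's convention for pending cases; it assumes tf.reduce_sum accepts
  # int64 inputs, which is not verified here.
  # def testInt64Reduce1D(self):
  #   np_arr = np.arange(1, 6).reshape([5]).astype(np.int64)
  #   self._compare(np_arr, [], False)
  #   self._compare(np_arr, [0], False)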
  def _compareGradient(self, shape, sum_shape, reduction_axes):
    if reduction_axes is not None and np.shape(reduction_axes) == (1,):
      # Test scalar reduction_axes argument
      self._compareGradient(shape, sum_shape, reduction_axes[0])
    x = np.arange(1.0, 49.0).reshape(shape).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_sum(t, reduction_axes)
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  shape,
                                                  su,
                                                  sum_shape,
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  def testGradient(self):
    self._compareGradient([2, 3, 4, 2], [2, 2], [1, 2])

  def testGradient2(self):
    self._compareGradient([2, 3, 4, 2], [2, 4, 2], [1])

  def testGradient3(self):
    self._compareGradient([2, 3, 4, 2], [2, 3, 2], [2])

  def testGradient4(self):
    self._compareGradient([2, 3, 4, 2], [], None)


class MeanReductionTest(tf.test.TestCase):

  def _compare(self, x, reduction_axes, keep_dims):
    np_sum = x
    count = 1
    for ra in reduction_axes[::-1]:
      np_sum = np.sum(np_sum, axis=ra, keepdims=keep_dims)
      count *= x.shape[ra]
    np_ans = np_sum / count
    with self.test_session():
      reduction_axes = np.array(reduction_axes).astype(np.int32)
      tf_ans = tf.reduce_mean(x, reduction_axes, keep_dims)
      out = tf_ans.eval()
    self.assertAllClose(np_ans, out)
    self.assertShapeEqual(np_ans, tf_ans)

  def _compareAll(self, x, reduction_axes):
    self._compare(x, reduction_axes, False)
    self._compare(x, reduction_axes, True)

  def testFloatReduce3D(self):
    # Create a 3D array of floats and reduce across all possible
    # dimensions
    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])

  def testGradient(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float32)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_mean(t, [1, 2])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 2],
                                                  x_init_value=x,
                                                  delta=1)
      self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)

      su = tf.reduce_mean(t, [0, 1, 2, 3])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [1],
                                                  x_init_value=x,
                                                  delta=1)
      self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)

      su = tf.reduce_mean(t, [])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 3, 4, 2],
                                                  x_init_value=x,
                                                  delta=1)
      self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)


class ProdReductionTest(tf.test.TestCase):

  def _compare(self, x, reduction_axes, keep_dims):
    np_ans = x
    if reduction_axes is None:
      np_ans = np.prod(np_ans, keepdims=keep_dims)
    else:
      for ra in reduction_axes[::-1]:
        np_ans = np.prod(np_ans, axis=ra, keepdims=keep_dims)
    with self.test_session():
      if reduction_axes is not None:
        reduction_axes = np.array(reduction_axes).astype(np.int32)
      tf_ans = tf.reduce_prod(x, reduction_axes, keep_dims)
      out = tf_ans.eval()
    self.assertAllClose(np_ans, out)
    self.assertShapeEqual(np_ans, tf_ans)

  def _compareAll(self, x, reduction_axes):
    self._compare(x, reduction_axes, False)
    self._compare(x, reduction_axes, True)

  def testFloatReduce3D(self):
    # Create a 3D array of floats and reduce across all possible
    # dimensions
    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
    self._compareAll(np_arr, None)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])

  def testGradient(self):
    s = [2, 3, 4, 2]
    # NOTE(kearnes): divide by 20 so product is a reasonable size
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float32) / 20.
    with self.test_session():
      t = tf.convert_to_tensor(x)

      su = tf.reduce_prod(t, [])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 3, 4, 2],
                                                  x_init_value=x,
                                                  delta=1)
      self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)

      su = tf.reduce_prod(t, [1, 2])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 2],
                                                  x_init_value=x,
                                                  delta=1)
      self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)

      su = tf.reduce_prod(t, [0, 1, 2, 3])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [1],
                                                  x_init_value=x,
                                                  delta=1)
      self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)

    # NOTE(kearnes): the current gradient calculation gives NaNs for 0 inputs
    x = np.arange(0.0, 48.0).reshape(s).astype(np.float32) / 20.
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_prod(t, [])
      jacob_t, _ = tf.test.compute_gradient(t,
                                            s,
                                            su,
                                            [2, 3, 4, 2],
                                            x_init_value=x,
                                            delta=1)
      with self.assertRaisesOpError("Tensor had NaN values"):
        tf.check_numerics(jacob_t, message="_ProdGrad NaN test").op.run()


class MinReductionTest(tf.test.TestCase):

  def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
    np_ans = x
    if reduction_axes is None:
      np_ans = np.amin(np_ans, keepdims=keep_dims)
    else:
      for ra in reduction_axes[::-1]:
        np_ans = np.amin(np_ans, axis=ra, keepdims=keep_dims)
    with self.test_session(use_gpu=use_gpu):
      if reduction_axes is not None:
        reduction_axes = np.array(reduction_axes).astype(np.int32)
      tf_ans = tf.reduce_min(x, reduction_axes, keep_dims)
      out = tf_ans.eval()
    self.assertAllClose(np_ans, out)
    self.assertShapeEqual(np_ans, tf_ans)

  def _compareAll(self, x, reduction_axes):
    self._compare(x, reduction_axes, False, use_gpu=True)
    self._compare(x, reduction_axes, False, use_gpu=False)
    self._compare(x, reduction_axes, True, use_gpu=True)
    self._compare(x, reduction_axes, True, use_gpu=False)

  def testFloatReduce3D(self):
    # Create a 3D array of floats and reduce across all possible
    # dimensions
    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])

  def testGradient(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_min(t, [1, 2])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 2],
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  def testGradient2(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_min(t, [1])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 4, 2],
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  def testGradient3(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_min(t, [2])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 3, 2],
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  def testGradient4(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_min(t)
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [1],
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)


class MaxReductionTest(tf.test.TestCase):

  def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
    np_ans = x
    if reduction_axes is None:
      np_ans = np.amax(np_ans, keepdims=keep_dims)
    else:
      for ra in reduction_axes[::-1]:
        np_ans = np.amax(np_ans, axis=ra, keepdims=keep_dims)
    with self.test_session(use_gpu=use_gpu):
      if reduction_axes is not None:
        reduction_axes = np.array(reduction_axes).astype(np.int32)
      tf_ans = tf.reduce_max(x, reduction_axes, keep_dims)
      out = tf_ans.eval()
    self.assertAllClose(np_ans, out)
    self.assertShapeEqual(np_ans, tf_ans)

  def _compareAll(self, x, reduction_axes):
    self._compare(x, reduction_axes, False, use_gpu=True)
    self._compare(x, reduction_axes, False, use_gpu=False)
    self._compare(x, reduction_axes, True, use_gpu=True)
    self._compare(x, reduction_axes, True, use_gpu=False)

  def testFloatReduce3D(self):
    # Create a 3D array of floats and reduce across all possible
    # dimensions
    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
    self._compareAll(np_arr, None)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])

  def testGradient(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_max(t, [1, 2])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 2],
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  def testGradient2(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_max(t, [1])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 4, 2],
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  def testGradient3(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_max(t, [2])
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [2, 3, 2],
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  def testGradient4(self):
    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    with self.test_session():
      t = tf.convert_to_tensor(x)
      su = tf.reduce_max(t)
      jacob_t, jacob_n = tf.test.compute_gradient(t,
                                                  s,
                                                  su,
                                                  [1],
                                                  x_init_value=x,
                                                  delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)


class AllReductionTest(tf.test.TestCase):

  def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
    np_ans = x
    if reduction_axes is None:
      np_ans = np.all(np_ans, keepdims=keep_dims)
    else:
      for ra in reduction_axes[::-1]:
        np_ans = np.all(np_ans, axis=ra, keepdims=keep_dims)
    with self.test_session(use_gpu=use_gpu):
      if reduction_axes is not None:
        reduction_axes = np.array(reduction_axes).astype(np.int32)
      tf_ans = tf.reduce_all(x, reduction_axes, keep_dims)
      out = tf_ans.eval()
    self.assertAllEqual(np_ans, out)
    self.assertShapeEqual(np_ans, tf_ans)

  def _compareAll(self, x, reduction_axes):
    self._compare(x, reduction_axes, False, use_gpu=True)
    self._compare(x, reduction_axes, False, use_gpu=False)
    self._compare(x, reduction_axes, True, use_gpu=True)
    self._compare(x, reduction_axes, True, use_gpu=False)

  def testAll3D(self):
    # Create a 3D array of bools and reduce across all possible
    # dimensions
    np_arr = (np.random.uniform(0, 1, 30) > 0.1).reshape([2, 3, 5])
    self._compareAll(np_arr, None)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])


class AnyReductionTest(tf.test.TestCase):

  def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
    np_ans = x
    if reduction_axes is None:
      np_ans = np.any(np_ans, keepdims=keep_dims)
    else:
      for ra in reduction_axes[::-1]:
        np_ans = np.any(np_ans, axis=ra, keepdims=keep_dims)
    with self.test_session(use_gpu=use_gpu):
      if reduction_axes is not None:
        reduction_axes = np.array(reduction_axes).astype(np.int32)
      tf_ans = tf.reduce_any(x, reduction_axes, keep_dims)
      out = tf_ans.eval()
    self.assertAllEqual(np_ans, out)
    self.assertShapeEqual(np_ans, tf_ans)

  def _compareAll(self, x, reduction_axes):
    self._compare(x, reduction_axes, False, use_gpu=True)
    self._compare(x, reduction_axes, False, use_gpu=False)
    self._compare(x, reduction_axes, True, use_gpu=True)
    self._compare(x, reduction_axes, True, use_gpu=False)

  def testAll3D(self):
    # Create a 3D array of bools and reduce across all possible
    # dimensions
    np_arr = (np.random.uniform(0, 1, 30) > 0.9).reshape([2, 3, 5])
    self._compareAll(np_arr, None)
    self._compareAll(np_arr, [])
    self._compareAll(np_arr, [0])
    self._compareAll(np_arr, [1])
    self._compareAll(np_arr, [2])
    self._compareAll(np_arr, [0, 1])
    self._compareAll(np_arr, [1, 2])
    self._compareAll(np_arr, [0, 2])
    self._compareAll(np_arr, [0, 1, 2])

  def testPartialShapes(self):
    # Input shape is unknown.
    c_unknown = tf.placeholder(tf.float32)
    s_unknown = tf.reduce_sum(c_unknown, [1, 2])
    self.assertEqual(tensor_shape.unknown_shape(), s_unknown.get_shape())

    # Input shape only has known rank.
    c_known_rank = tf.placeholder(tf.float32)
    c_known_rank.set_shape(tensor_shape.unknown_shape(ndims=3))
    s_known_rank = tf.reduce_sum(c_known_rank, [1, 2], keep_dims=True)
    self.assertEqual(3, s_known_rank.get_shape().ndims)

    # Reduction indices are unknown.
    unknown_indices = tf.placeholder(tf.int32)
    c_unknown_indices = tf.constant([[10.0], [20.0]])
    s_unknown_indices = tf.reduce_sum(c_unknown_indices, unknown_indices,
                                      keep_dims=False)
    self.assertEqual(tensor_shape.unknown_shape(),
                     s_unknown_indices.get_shape())
    s_unknown_indices_keep = tf.reduce_sum(c_unknown_indices, unknown_indices,
                                           keep_dims=True)
    self.assertEqual(2, s_unknown_indices_keep.get_shape().ndims)


if __name__ == "__main__":
  tf.test.main()