diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index ccfc17da4c4..256d6aa8aa6 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -315,6 +315,7 @@ class LayoutAssignment : public HloPassInterface {
   ComputationLayout* entry_computation_layout_;
 
+ protected:
   // Map containing the layouts of all computations assigned so
   // far. Computations are handled in a topological sort where computations are
   // handled before their caller instructions so the layouts of caller
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 35e2f216d74..091c35df655 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -32,92 +32,7 @@ limitations under the License.
 namespace xla {
 namespace {
 
-class SliceTest : public ClientLibraryTestBase {
- protected:
-  template <typename NativeT>
-  void RunSliceTenToTwo() {
-    std::vector<NativeT> constant;
-    constant.reserve(10);
-    for (int i = 0; i < 10; ++i) {
-      constant.push_back(static_cast<NativeT>(i));
-    }
-
-    ComputationBuilder builder(client_, TestName());
-    auto original = builder.ConstantR1<NativeT>(constant);
-    builder.Slice(original, {2}, {4}, {1});
-
-    const std::vector<NativeT> expected = {static_cast<NativeT>(2),
-                                           static_cast<NativeT>(3)};
-    ComputeAndCompareR1<NativeT>(&builder, expected, {});
-  }
-};
-
-XLA_TEST_F(SliceTest, SliceZeroToZeroF32) {
-  ComputationBuilder builder(client_, TestName());
-  auto original = builder.ConstantR1<float>({});
-  builder.Slice(original, {0}, {0}, {1});
-
-  ComputeAndCompareR1<float>(&builder, {}, {});
-}
-
-XLA_TEST_F(SliceTest, SliceTenToZeroF32) {
-  ComputationBuilder builder(client_, TestName());
-  std::vector<float> constant(10, 0.3);
-  auto original = builder.ConstantR1<float>(constant);
-  builder.Slice(original, {7}, {7}, {1});
-
-  ComputeAndCompareR1<float>(&builder, {}, {});
-}
-
-TEST_F(SliceTest, SliceTenToTwoF32) { RunSliceTenToTwo<float>(); }
-
-XLA_TEST_F(SliceTest, SliceTenToTwoF64) { RunSliceTenToTwo<double>(); }
-
-TEST_F(SliceTest, SliceTenToTwoU32) { RunSliceTenToTwo<uint32>(); }
-
-TEST_F(SliceTest, SliceTenToTwoS32) { RunSliceTenToTwo<int32>(); }
-
-XLA_TEST_F(SliceTest, SliceTenToTwoU64) { RunSliceTenToTwo<uint64>(); }
-
-XLA_TEST_F(SliceTest, SliceTenToTwoS64) { RunSliceTenToTwo<int64>(); }
-
-TEST_F(SliceTest, SliceTenToTen) {
-  const std::vector<float> values = {0.0, 1.0, 2.0, 3.0, 4.0,
-                                     5.0, 6.0, 7.0, 8.0, 9.0};
-
-  ComputationBuilder builder(client_, TestName());
-  auto original = builder.ConstantR1<float>(values);
-  builder.Slice(original, {0}, {10}, {1});
-
-  ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001));
-}
-
-TEST_F(SliceTest, SliceLastFourOf1024) {
-  std::vector<float> values(1024);
-  std::iota(values.begin(), values.end(), 0.0);
-
-  ComputationBuilder builder(client_, TestName());
-  auto original = builder.ConstantR1<float>(values);
-  builder.Slice(original, {1024 - 4}, {1024}, {1});
-
-  const std::vector<float> expected = {1020, 1021, 1022, 1023};
-  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
-}
-
-// TODO(b/28491443): Fix wrong result on CPU and GPU. Failed on
-// 2016-05-01. Also b/28508652
-TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
-  std::vector<float> values(4096);
-  std::iota(values.begin(), values.end(), 0.0);
-
-  ComputationBuilder builder(client_, TestName());
-  auto original = builder.ConstantR1<float>(values);
-  builder.Slice(original, {7}, {7 + 1024}, {1});
-
-  std::vector<float> expected(1024);
-  std::iota(expected.begin(), expected.end(), 7.0);
-  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
-}
+class SliceTest : public ClientLibraryTestBase {};
 
 XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
   ComputationBuilder builder(client_, TestName());
@@ -208,6 +123,70 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
   ComputeAndCompareR4<float>(&builder, *expected, {}, ErrorSpec(0.000001));
 }
 
+struct R1Spec {
+  int64 input_dim0;
+  int64 slice_start;
+  int64 slice_limit;
+  int64 slice_stride;
+};
+
+// Parameterized test that generates R1 values, slices them according
+// to the R1Spec, and compares the result with a computed version.
+class SliceR1Test : public ClientLibraryTestBase,
+                    public ::testing::WithParamInterface<R1Spec> {
+ protected:
+  template <typename NativeT>
+  void Run(const R1Spec& spec) {
+    std::vector<NativeT> input(spec.input_dim0);
+    std::iota(input.begin(), input.end(), NativeT());
+
+    ComputationBuilder builder(client_, TestName());
+    auto original = builder.ConstantR1<NativeT>(input);
+    builder.Slice(original, {spec.slice_start}, {spec.slice_limit},
+                  {spec.slice_stride});
+
+    std::vector<NativeT> expected;
+    for (int i = spec.slice_start; i < spec.slice_limit;
+         i += spec.slice_stride) {
+      expected.push_back(i);
+    }
+
+    ComputeAndCompareR1<NativeT>(&builder, expected, {});
+  }
+};
+
+XLA_TEST_P(SliceR1Test, DoIt) {
+  Run<float>(GetParam());
+  Run<double>(GetParam());
+  Run<uint32>(GetParam());
+  Run<int32>(GetParam());
+  Run<uint64>(GetParam());
+  Run<int64>(GetParam());
+}
+
+INSTANTIATE_TEST_CASE_P(              //
+    SliceR1TestInstantiation,         //
+    SliceR1Test,                      //
+    ::testing::Values(                //
+        R1Spec{10, 0, 0, 1},          //
+        R1Spec{10, 7, 7, 1},          //
+        R1Spec{10, 2, 4, 1},          //
+        R1Spec{10, 2, 4, 1},          //
+        R1Spec{10, 2, 4, 1},          //
+        R1Spec{10, 2, 4, 1},          //
+        R1Spec{10, 2, 4, 1},          //
+        R1Spec{10, 2, 4, 1},          //
+        R1Spec{10, 0, 10, 1},         //
+        R1Spec{1024, 1024 - 4, 1024, 1},  //
+        R1Spec{4096, 7, 7 + 1024, 1},     //
+        R1Spec{10, 0, 10, 2},         //
+        R1Spec{10, 0, 10, 3},         //
+        R1Spec{10, 0, 10, 4},         //
+        R1Spec{10, 0, 10, 5},         //
+        R1Spec{10, 0, 10, 10}         //
+    )                                 //
+);
+
 struct R2Spec {
   int64 input_dim0;
   int64 input_dim1;
@@ -222,13 +201,13 @@ struct R2Spec {
 class SliceR2Test : public ClientLibraryTestBase,
                     public ::testing::WithParamInterface<R2Spec> {};
 
-TEST_P(SliceR2Test, DoIt) {
+XLA_TEST_P(SliceR2Test, DoIt) {
   const R2Spec& spec = GetParam();
   Array2D<float> input(spec.input_dim0, spec.input_dim1);
   input.FillUnique();
 
   ComputationBuilder builder(client_, TestName());
-  auto a = builder.ConstantR2FromArray2D<float>(input);
+  auto a = builder.ConstantR2FromArray2DWithLayout<float>(input, spec.layout);
   builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
 
   std::unique_ptr<Array2D<float>> expected = ReferenceUtil::Slice2D(
@@ -257,6 +236,18 @@ INSTANTIATE_TEST_CASE_P(
     R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}},
             LayoutUtil::MakeLayout({1, 0})},
     R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}},
+            LayoutUtil::MakeLayout({1, 0})},
+    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}},
+            LayoutUtil::MakeLayout({0, 1})},
+    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}},
+            LayoutUtil::MakeLayout({1, 0})},
+    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}},
+            LayoutUtil::MakeLayout({0, 1})},
+    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}},
+            LayoutUtil::MakeLayout({1, 0})},
+    R2Spec {10, 10, {{0, 0}},
{{10, 10}}, {{2, 2}}, + LayoutUtil::MakeLayout({0, 1})}, + R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, LayoutUtil::MakeLayout({1, 0})} ) ); diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 77358d12533..4041427dca6 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -515,7 +515,10 @@ py_test( size = "small", srcs = ["python/keras/utils/data_utils_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], + tags = [ + "noasan", # times out + "notsan", + ], deps = [ ":keras", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 9e5282595e3..01d769c80df 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -587,6 +587,7 @@ py_test( name = "head_test", size = "medium", srcs = ["python/learn/estimators/head_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 deps = [ diff --git a/tensorflow/contrib/nn/python/ops/sampling_ops.py b/tensorflow/contrib/nn/python/ops/sampling_ops.py index 7a9eed511bd..2ae529e0155 100644 --- a/tensorflow/contrib/nn/python/ops/sampling_ops.py +++ b/tensorflow/contrib/nn/python/ops/sampling_ops.py @@ -105,7 +105,6 @@ def _rank_resample(weights, biases, inputs, sampled_values, num_resampled, return resampled, true_expected_count, resampled_expected_count -# TODO(ccolby): Before checkin, Add reference to TAPAS paper when in arxiv.org. def rank_sampled_softmax_loss(weights, biases, labels, diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 7fa05bf39ba..706de58a0af 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -21,14 +21,31 @@ py_library( ], ) +cuda_py_tests( + name = "reconstruction_ops_test", + srcs = ["python/kernel_tests/reconstruction_ops_test.py"], + additional_deps = [ + ":signal_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_tests( name = "shape_ops_test", - size = "small", srcs = ["python/kernel_tests/shape_ops_test.py"], additional_deps = [ ":signal_py", "//third_party/py/numpy", "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index b4be2f7b4cc..d0f6c1d0c6b 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -14,16 +14,21 @@ # ============================================================================== """##Signal ops. -@@frames +@@frame @@hamming_window @@hann_window +@@overlap_and_add """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.signal.python.ops.shape_ops import frames +from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add +from tensorflow.contrib.signal.python.ops.shape_ops import frame +# `frame` used to be named `frames`, which is a noun and not a verb. +# Keep an alias to `frames` for backwards compatibility. 
+from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames from tensorflow.contrib.signal.python.ops.window_ops import hamming_window from tensorflow.contrib.signal.python.ops.window_ops import hann_window diff --git a/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py new file mode 100644 index 00000000000..5c9b2ac5181 --- /dev/null +++ b/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py @@ -0,0 +1,192 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for reconstruction_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.signal.python.ops import reconstruction_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ReconstructionOpsTest(test.TestCase): + + def __init__(self, *args, **kwargs): + super(ReconstructionOpsTest, self).__init__(*args, **kwargs) + self.batch_size = 3 + self.frames = 3 + self.samples = 5 + + self.bases = np.array(range(2, 5)) + exponents = np.array(range(self.frames * self.samples)) + powers = np.power(self.bases[:, np.newaxis], exponents[np.newaxis, :]) + + self.powers = np.reshape(powers, [self.batch_size, self.frames, + self.samples]) + self.frame_hop = 2 + + # Hand computed example using powers of unique numbers: this is easily + # verified. + self.expected_string = ["1", "10", "100100", "1001000", "10010010000", + "100100000000", "1001000000000", "10000000000000", + "100000000000000"] + + def test_all_ones(self): + signal = constant_op.constant(np.ones((3, 5)), dtype=dtypes.int64) + reconstruction = reconstruction_ops.overlap_and_add(signal, 2) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(reconstruction) + + expected_output = np.array([1, 1, 2, 2, 3, 2, 2, 1, 1]) + + self.assertAllClose(output, expected_output) + + def test_simple(self): + def make_input(frame_length, num_frames=3): + """Generate a tensor of num_frames frames of frame_length.""" + return np.reshape(np.arange(1, num_frames * frame_length + 1), + (-1, frame_length)) + + # List of (signal, expected_result, frame_hop). + configurations = [ + # All hop lengths on a frame length of 2. + (make_input(2), [1, 5, 9, 6], 1), + (make_input(2), [1, 2, 3, 4, 5, 6], 2), + + # All hop lengths on a frame length of 3. + (make_input(3), [1, 6, 15, 14, 9], 1), + (make_input(3), [1, 2, 7, 5, 13, 8, 9], 2), + (make_input(3), [1, 2, 3, 4, 5, 6, 7, 8, 9], 3), + + # All hop lengths on a frame length of 4. 
+ (make_input(4), [1, 7, 18, 21, 19, 12], 1), + (make_input(4), [1, 2, 8, 10, 16, 18, 11, 12], 2), + (make_input(4), [1, 2, 3, 9, 6, 7, 17, 10, 11, 12], 3), + (make_input(4), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4), + ] + + with self.test_session(use_gpu=True): + for signal, expected, frame_hop in configurations: + reconstruction = reconstruction_ops.overlap_and_add( + np.array(signal), frame_hop).eval() + expected_output = np.array(expected) + self.assertAllClose(reconstruction, expected_output) + + def test_powers(self): + signal = constant_op.constant(np.squeeze(self.powers[0, :, :]), + dtype=dtypes.int64) + reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(reconstruction) + string_output = [np.base_repr(x, self.bases[0]) for x in output] + + self.assertEqual(string_output, self.expected_string) + + def test_batch(self): + signal = constant_op.constant(self.powers, dtype=dtypes.int64) + reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(reconstruction) + + accumulator = True + for i in range(self.batch_size): + string_output = [np.base_repr(x, self.bases[i]) for x in output[i, :]] + accumulator = accumulator and (string_output == self.expected_string) + + self.assertTrue(accumulator) + + def test_one_element_batch(self): + input_matrix = np.squeeze(self.powers[0, :, :]) + input_matrix = input_matrix[np.newaxis, :, :].astype(float) + signal = constant_op.constant(input_matrix, dtype=dtypes.float32) + reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(reconstruction) + + string_output = [np.base_repr(int(x), self.bases[0]) for x in + np.squeeze(output)] + + self.assertEqual(output.shape, (1, 9)) + self.assertEqual(string_output, self.expected_string) + + def test_gradient(self): + configurations = [ + ((1, 128), 1), + ((5, 35), 17), + ((10, 128), 128), + ((2, 10, 128), 127), + ((2, 2, 10, 128), 126), + ((2, 2, 2, 10, 128), 125), + ] + + for shape, frame_hop in configurations: + with self.test_session(use_gpu=True) as sess: + signal = array_ops.zeros(shape) + reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop) + loss = math_ops.reduce_sum(reconstruction) + # Increasing any sample in the input frames by one will increase the sum + # of all the samples in the reconstruction by 1, so the gradient should + # be all ones, no matter the shape or hop. + gradient = sess.run(gradients_impl.gradients([loss], [signal])[0]) + self.assertTrue((gradient == 1.0).all()) + + def test_gradient_batch(self): + with self.test_session(use_gpu=True) as sess: + signal = array_ops.zeros((2, 10, 10)) + frame_hop = 10 + reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop) + + # Multiply the first batch-item's reconstruction by zeros. This will block + # gradient from flowing into the first batch item from the loss. Multiply + # the second batch item by the integers from 0 to 99. Since there is zero + # overlap, the gradient for this batch item will be 0-99 shaped as (10, + # 10). + reconstruction *= array_ops.stack( + [array_ops.zeros((100,)), math_ops.to_float(math_ops.range(100))]) + loss = math_ops.reduce_sum(reconstruction) + + # Verify that only the second batch item receives gradient. 
+ gradient = sess.run(gradients_impl.gradients([loss], [signal])[0]) + expected_gradient = np.stack([ + np.zeros((10, 10)), + np.reshape(np.arange(100).astype(np.float32), (10, 10))]) + self.assertAllEqual(expected_gradient, gradient) + + def test_gradient_numerical(self): + with self.test_session(use_gpu=True): + shape = (2, 10, 10) + framed_signal = array_ops.zeros(shape) + frame_hop = 10 + reconstruction = reconstruction_ops.overlap_and_add( + framed_signal, frame_hop) + error = test.compute_gradient_error( + framed_signal, shape, reconstruction, [2, 100]) + self.assertLess(error, 2e-5) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py index e07942875fd..8633ced599f 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py @@ -24,18 +24,18 @@ from tensorflow.contrib.signal.python.ops import shape_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class FramesTest(test.TestCase): +class FrameTest(test.TestCase): def test_mapping_of_indices_without_padding(self): - with self.test_session(): + with self.test_session(use_gpu=True): tensor = constant_op.constant(np.arange(9152), dtypes.int32) tensor = array_ops.expand_dims(tensor, 0) - result = shape_ops.frames(tensor, 512, 180) - result = result.eval() + result = shape_ops.frame(tensor, 512, 180, pad_end=False).eval() expected = np.tile(np.arange(512), (49, 1)) expected += np.tile(np.arange(49) * 180, (512, 1)).T @@ -46,15 +46,14 @@ class FramesTest(test.TestCase): self.assertAllEqual(expected, result) def test_mapping_of_indices_with_padding(self): - with self.test_session(): + with self.test_session(use_gpu=True): tensor = constant_op.constant(np.arange(10000), dtypes.int32) tensor = array_ops.expand_dims(tensor, 0) - result = shape_ops.frames(tensor, 512, 192) - result = result.eval() + result = shape_ops.frame(tensor, 512, 192, pad_end=True).eval() - expected = np.tile(np.arange(512), (51, 1)) - expected += np.tile(np.arange(51) * 192, (512, 1)).T + expected = np.tile(np.arange(512), (53, 1)) + expected += np.tile(np.arange(53) * 192, (512, 1)).T expected[expected >= 10000] = 0 @@ -63,6 +62,277 @@ class FramesTest(test.TestCase): self.assertAllEqual(expected, result) + def test_invalid_inputs(self): + # Rank 0 input signal. + with self.assertRaises(ValueError): + shape_ops.frame(1, 1, 1) + + # If the rank is unknown, do not raise an exception. + shape_ops.frame(array_ops.placeholder(dtypes.float32), 1, 1) + + # Non-scalar frame_length. + with self.assertRaises(ValueError): + shape_ops.frame([1], [1], 1) + + # Non-scalar frame_step. + with self.assertRaises(ValueError): + shape_ops.frame([1], 1, [1]) + + # Non-scalar pad_value. 
+ with self.assertRaises(ValueError): + shape_ops.frame([1], 1, 1, pad_end=True, pad_value=[1]) + + def test_length_zero(self): + signal = constant_op.constant([], dtype=dtypes.float32) + frame_length = 2 + frame_step = 1 + + with self.test_session(use_gpu=True): + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + self.assertEqual((0, 2), result.shape) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False).eval() + self.assertEqual((0, 2), result.shape) + + def test_shape_inference(self): + signal = array_ops.placeholder(dtypes.int32, shape=[1, 1]) + frame_length = 2 + frame_step = 1 + # Shape inference is able to detect the rank and inner-most dimension + # if frame_length is known at graph definition time. + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99) + self.assertEqual([1, 1, 2], result.shape.as_list()) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False) + self.assertEqual([1, 0, 2], result.shape.as_list()) + + # If frame_length is not known, rank and (known) outer and inner dimensions + # are inferred. + signal = array_ops.placeholder(dtypes.int32, shape=[1, 2, 3, 4]) + frame_length = array_ops.placeholder(dtypes.int32, shape=[]) + frame_step = 1 + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99, axis=1) + self.assertEqual([1, None, None, 3, 4], result.shape.as_list()) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False, axis=1) + self.assertEqual([1, None, None, 3, 4], result.shape.as_list()) + + # If frame_length and inner-most dimension is known, rank, inner dimensions, + # and known outer dimensions are inferred. + signal = array_ops.placeholder(dtypes.int32, + shape=[None, 5, None, 20, 5, 3]) + frame_length = 4 + frame_step = 3 + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99, axis=3) + self.assertEqual([None, 5, None, 7, 4, 5, 3], result.shape.as_list()) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False, axis=3) + self.assertEqual([None, 5, None, 6, 4, 5, 3], result.shape.as_list()) + + # Test that shape inference is consistent with actual returned shapes for + # small values of signal_length, frame_length, frame_step, and pad_end in + # [True, False]. + frame_step = 1 + for signal_length in range(2): + signal = [0] * signal_length + for frame_length in range(2): + for pad_end in [False, True]: + op = shape_ops.frame(signal, frame_length, frame_step, + pad_end=pad_end, pad_value=99) + with self.test_session(use_gpu=True): + result = op.eval() + self.assertEqual(op.shape.as_list(), list(result.shape)) + + def test_basic_mono(self): + signal = np.arange(6) + frame_length = 3 + frame_step = 2 + + with self.test_session(use_gpu=True): + for rank in range(5): + nd_signal = np.reshape(signal, (1,) * rank + signal.shape) + + # With padding, we pad the last frame with pad_value. + result = shape_ops.frame(nd_signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + expected_inner_frames = np.array([[0, 1, 2], [2, 3, 4], [4, 5, 99]]) + expected = np.reshape( + expected_inner_frames, (1,) * rank + expected_inner_frames.shape) + self.assertAllEqual(expected, result) + + # Without padding, we drop the last frame. 
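+        # With frame_length=3 and frame_step=2 on the 6-sample signal, only
+        # frames [0, 1, 2] and [2, 3, 4] fit completely, so sample 5 is
+        # dropped.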
+ expected_inner_frames = np.array([[0, 1, 2], [2, 3, 4]]) + expected = np.reshape( + expected_inner_frames, (1,) * rank + expected_inner_frames.shape) + result = shape_ops.frame(nd_signal, frame_length, frame_step, + pad_end=False).eval() + self.assertAllEqual(expected, result) + + def test_basic_stereo(self): + signal = np.vstack([np.arange(6), + np.arange(6) + 10]) + frame_length = 3 + frame_step = 2 + + with self.test_session(use_gpu=True): + for rank in range(5): + nd_signal = np.reshape(signal, (1,) * rank + signal.shape) + + # With padding, we pad the last frame with pad_value. + result = shape_ops.frame(nd_signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + expected_inner_frames = np.array([ + [[0, 1, 2], [2, 3, 4], [4, 5, 99]], + [[10, 11, 12], [12, 13, 14], [14, 15, 99]]]) + expected = np.reshape( + expected_inner_frames, (1,) * rank + expected_inner_frames.shape) + self.assertAllEqual(expected, result) + + # Without padding, we drop the last frame. + expected_inner_frames = np.array([[[0, 1, 2], [2, 3, 4]], + [[10, 11, 12], [12, 13, 14]]]) + expected = np.reshape( + expected_inner_frames, (1,) * rank + expected_inner_frames.shape) + result = shape_ops.frame(nd_signal, frame_length, frame_step, + pad_end=False).eval() + self.assertAllEqual(expected, result) + + def test_complex_shape(self): + signal = np.vstack([np.arange(6), + np.arange(6) + 10, + np.arange(6) + 20, + np.arange(6) + 30, + np.arange(6) + 40, + np.arange(6) + 50]) + signal = np.reshape(signal, (2, 1, 3, 1, 6)) + frame_length = 3 + frame_step = 2 + + with self.test_session(use_gpu=True): + # With padding, we pad the last frame with pad_value. + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + # Resulting shape is (2, 1, 3, 1, 3, 3). + expected = [[[[[[0, 1, 2], [2, 3, 4], [4, 5, 99]]], + [[[10, 11, 12], [12, 13, 14], [14, 15, 99]]], + [[[20, 21, 22], [22, 23, 24], [24, 25, 99]]]]], + [[[[[30, 31, 32], [32, 33, 34], [34, 35, 99]]], + [[[40, 41, 42], [42, 43, 44], [44, 45, 99]]], + [[[50, 51, 52], [52, 53, 54], [54, 55, 99]]]]]] + self.assertAllEqual(expected, result) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False).eval() + # Resulting shape is (2, 1, 3, 1, 3, 2). 
+ expected = [[[[[[0, 1, 2], [2, 3, 4]]], + [[[10, 11, 12], [12, 13, 14]]], + [[[20, 21, 22], [22, 23, 24]]]]], + [[[[[30, 31, 32], [32, 33, 34]]], + [[[40, 41, 42], [42, 43, 44]]], + [[[50, 51, 52], [52, 53, 54]]]]]] + self.assertAllEqual(expected, result) + + def test_axis(self): + signal = np.reshape(np.arange(16), (2, 4, 2)) + with self.test_session(use_gpu=True): + result = shape_ops.frame(signal, frame_length=2, frame_step=2, + pad_end=True, axis=1) + expected = np.reshape(np.arange(16), (2, 2, 2, 2)) + self.assertAllEqual(expected, result.eval()) + + result = shape_ops.frame(signal, frame_length=2, frame_step=1, + pad_end=True, axis=1) + expected = [[[[0, 1], [2, 3]], + [[2, 3], [4, 5]], + [[4, 5], [6, 7]], + [[6, 7], [0, 0]]], + [[[8, 9], [10, 11]], + [[10, 11], [12, 13]], + [[12, 13], [14, 15]], + [[14, 15], [0, 0]]]] + self.assertAllEqual(expected, result.eval()) + + result = shape_ops.frame(signal, frame_length=3, frame_step=1, + pad_end=True, axis=1) + expected = [[[[0, 1], [2, 3], [4, 5]], + [[2, 3], [4, 5], [6, 7]], + [[4, 5], [6, 7], [0, 0]], + [[6, 7], [0, 0], [0, 0]]], + [[[8, 9], [10, 11], [12, 13]], + [[10, 11], [12, 13], [14, 15]], + [[12, 13], [14, 15], [0, 0]], + [[14, 15], [0, 0], [0, 0]]]] + self.assertAllEqual(expected, result.eval()) + + def test_window_larger_than_signal(self): + signal = constant_op.constant([[1, 2], [11, 12]], dtype=dtypes.float32) + frame_length = 4 + frame_step = 1 + + with self.test_session(use_gpu=True): + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + self.assertAllClose([[[1, 2, 99, 99], [2, 99, 99, 99]], + [[11, 12, 99, 99], [12, 99, 99, 99]]], result) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False).eval() + self.assertEqual((2, 0, 4), result.shape) + + frame_step = 2 + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + self.assertAllClose([[[1, 2, 99, 99]], [[11, 12, 99, 99]]], result) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False).eval() + self.assertEqual((2, 0, 4), result.shape) + + def test_preserves_type(self): + signal = math_ops.range(10, dtype=dtypes.float64) + frame_length = 2 + frame_step = 3 + + with self.test_session(use_gpu=True): + result = shape_ops.frame(signal, frame_length, frame_step) + self.assertEqual(result.dtype, signal.dtype) + + def test_dynamic_tensor(self): + # Show that frame works even when the dimensions of its input are + # not known at graph creation time. 
+    input_signal = np.vstack([np.arange(4), np.arange(4) + 10,
+                              np.arange(4) + 20])
+    frame_length = 2
+    frame_step = 2
+
+    with self.test_session(use_gpu=True) as sess:
+      signal_placeholder = array_ops.placeholder(shape=(None, None),
+                                                 dtype=dtypes.float32)
+      result = sess.run(shape_ops.frame(
+          signal_placeholder, frame_length, frame_step),
+          feed_dict={signal_placeholder: input_signal})
+      self.assertAllEqual([[[0, 1], [2, 3]],
+                           [[10, 11], [12, 13]],
+                           [[20, 21], [22, 23]]], result)
+
+  def test_gradient_numerical(self):
+    with self.test_session(use_gpu=True):
+      signal_shape = (2, 128)
+      signal = array_ops.ones(signal_shape)
+      frame_length = 33
+      frame_step = 9
+      frames = shape_ops.frame(signal, frame_length, frame_step)
+      error = test.compute_gradient_error(
+          signal, signal_shape, frames, frames.shape.as_list())
+      self.assertLess(error, 2e-5)
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
new file mode 100644
index 00000000000..f5f443ad09c
--- /dev/null
+++ b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
@@ -0,0 +1,144 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Signal reconstruction via overlapped addition of frames."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.signal.python.ops import shape_ops
+from tensorflow.contrib.signal.python.ops import util_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def _shuffle_to_front(input_tensor, k):
+  """Shuffles the last `k` indices of `input_tensor` to the front.
+
+  Transposes `input_tensor` to have the last `k` indices at the front. The input
+  may have arbitrary rank and unknown shape.
+
+  Args:
+    input_tensor: A `Tensor` of arbitrary rank and unknown shape.
+    k: A scalar `Tensor` specifying how many indices to shuffle.
+
+  Returns:
+    A transposed version of `input_tensor` with `k` indices shuffled to the
+    front.
+
+  Raises:
+    ValueError: If `input_tensor` is not at least rank `k` or `k` is not scalar.
+  """
+  k = ops.convert_to_tensor(k, name="k")
+  k.shape.with_rank(0)
+  k_static = tensor_util.constant_value(k)
+  if k_static is not None:
+    input_tensor.shape.with_rank_at_least(k_static)
+
+  rank = array_ops.rank(input_tensor)
+  outer_indices, inner_indices = array_ops.split(math_ops.range(rank),
+                                                 [rank - k, k])
+  permutation = array_ops.concat([inner_indices, outer_indices], 0)
+
+  return array_ops.transpose(input_tensor, perm=permutation)
+
+
+def overlap_and_add(signal, frame_step, name=None):
+  """Reconstructs a signal from a framed representation.
+ + Adds potentially overlapping frames of a signal with shape + `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. + The resulting tensor has shape `[..., output_size]` where + + output_size = (frames - 1) * frame_step + frame_length + + Args: + signal: A [..., frames, frame_length] `Tensor`. All dimensions may be + unknown, and rank must be at least 2. + frame_step: An integer or scalar `Tensor` denoting overlap offsets. Must be + less than or equal to `frame_length`. + name: An optional name for the operation. + + Returns: + A `Tensor` with shape `[..., output_size]` containing the overlap-added + frames of `signal`'s inner-most two dimensions. + + Raises: + ValueError: If `signal`'s rank is less than 2, `frame_step` is not a scalar + integer or `frame_step` is greater than `frame_length`. + """ + with ops.name_scope(name, "overlap_and_add", [signal, frame_step]): + signal = ops.convert_to_tensor(signal, name="signal") + signal.shape.with_rank_at_least(2) + frame_step = ops.convert_to_tensor(frame_step, name="frame_step") + frame_step.shape.assert_has_rank(0) + if not frame_step.dtype.is_integer: + raise ValueError("frame_step must be an integer. Got %s" % + frame_step.dtype) + + # If frame_length and frame_step are known at graph construction time, check + # frame_step is less than or equal to frame_length. + frame_step_static = tensor_util.constant_value(frame_step) + if (frame_step_static is not None and signal.shape.ndims is not None and + signal.shape[-1].value is not None and + frame_step_static > signal.shape[-1].value): + raise ValueError( + "frame_step (%d) must be less than or equal to frame_length (%d)" % ( + frame_step_static, signal.shape[-1].value)) + + signal_shape = array_ops.shape(signal) + + # All dimensions that are not part of the overlap-and-add. Can be empty for + # rank 2 inputs. + outer_dimensions = signal_shape[:-2] + + signal_rank = array_ops.rank(signal) + frames = signal_shape[-2] + frame_length = signal_shape[-1] + + subframe_length = util_ops.gcd(frame_length, frame_step) + subframe_step = frame_step // subframe_length + subframes_per_frame = frame_length // subframe_length + output_size = frame_step * (frames - 1) + frame_length + output_subframes = output_size // subframe_length + + # To avoid overlap-adding sample-by-sample, we overlap-add at the "subframe" + # level, where a subframe is gcd(frame_length, frame_step). Reshape signal + # from [..., frames, frame_length] into [..., subframes, subframe_length]. + subframe_shape = array_ops.concat( + [outer_dimensions, [-1, subframe_length]], 0) + subframe_signal = array_ops.reshape(signal, subframe_shape) + + # Now we shuffle the last [subframes, subframe_length] dimensions to the + # front. + # TODO(rjryan): Add an axis argument to unsorted_segment_sum so we can + # avoid this pair of transposes. + subframe_signal = _shuffle_to_front(subframe_signal, 2) + + # Use unsorted_segment_sum to add overlapping subframes together. + segment_ids = array_ops.reshape(shape_ops.frame( + math_ops.range(output_subframes), subframes_per_frame, subframe_step, + pad_end=False), [-1]) + result = math_ops.unsorted_segment_sum(subframe_signal, segment_ids, + num_segments=output_subframes) + + # result is a [subframes, subframe_length, ...outer_dimensions] tensor. We + # return a [...outer_dimensions, output_size] tensor with a transpose and + # reshape. 
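+    # For example, with outer_dimensions=[2], frames=3, frame_length=4 and
+    # frame_step=2: subframe_length=2, output_size=8 and output_subframes=4,
+    # so the [4, 2, 2] result below becomes a [2, 8] tensor.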
+    result_shape = array_ops.concat([outer_dimensions, [output_size]], 0)
+    return array_ops.reshape(_shuffle_to_front(result, signal_rank - 2),
+                             result_shape)
diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py
index 4914f19be75..dc7a073242c 100644
--- a/tensorflow/contrib/signal/python/ops/shape_ops.py
+++ b/tensorflow/contrib/signal/python/ops/shape_ops.py
@@ -18,70 +18,173 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.framework import dtypes
+
+from tensorflow.contrib.signal.python.ops import util_ops
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 
 
-def frames(signal, frame_length, frame_step, name=None):
-  """Frame a signal into overlapping frames.
+def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
+  """Infers the shape of the return value of `frame`."""
+  frame_length = tensor_util.constant_value(frame_length)
+  frame_step = tensor_util.constant_value(frame_step)
+  axis = tensor_util.constant_value(axis)
+  if signal.shape.ndims is None:
+    return None
+  if axis is None:
+    return [None] * (signal.shape.ndims + 1)
 
-  May be used in front of spectral functions.
+  signal_shape = signal.shape.as_list()
+  num_frames = None
+  frame_axis = signal_shape[axis]
+  outer_dimensions = signal_shape[:axis]
+  inner_dimensions = signal_shape[axis:][1:]
+  if signal_shape and frame_axis is not None:
+    if frame_step and frame_length is not None:
+      if pad_end:
+        # Double negative is so that we round up.
+        num_frames = -(-frame_axis // frame_step)
+      else:
+        num_frames = (frame_axis - frame_length + frame_step) // frame_step
+      num_frames = max(0, num_frames)
+  return outer_dimensions + [num_frames, frame_length] + inner_dimensions
+
+
+def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
+          name=None):
+  """Expands `signal`'s `axis` dimension into frames of `frame_length`.
+
+  Slides a window of size `frame_length` over `signal`'s `axis` dimension
+  with a stride of `frame_step`, replacing the `axis` dimension with
+  `[frames, frame_length]` frames.
+
+  If `pad_end` is True, window positions that are past the end of the `axis`
+  dimension are padded with `pad_value` until the window moves fully past the
+  end of the dimension. Otherwise, only window positions that fully overlap the
+  `axis` dimension are produced.
 
   For example:
 
   ```python
   pcm = tf.placeholder(tf.float32, [None, 9152])
-  frames = tf.contrib.signal.frames(pcm, 512, 180)
+  frames = tf.contrib.signal.frame(pcm, 512, 180)
   magspec = tf.abs(tf.spectral.rfft(frames, [512]))
   image = tf.expand_dims(magspec, 3)
   ```
 
   Args:
-    signal: A `Tensor` of shape `[batch_size, signal_length]`.
-    frame_length: An `int32` or `int64` `Tensor`. The length of each frame.
-    frame_step: An `int32` or `int64` `Tensor`. The step between frames.
-    name: A name for the operation (optional).
+    signal: A `[..., samples, ...]` `Tensor`. The rank and dimensions
+      may be unknown. Rank must be at least 1.
+    frame_length: The frame length in samples. An integer or scalar `Tensor`.
+    frame_step: The frame hop size in samples. An integer or scalar `Tensor`.
+    pad_end: Whether to pad the end of `signal` with `pad_value`.
+    pad_value: An optional scalar `Tensor` to use where the input signal
+      does not exist when `pad_end` is True.
+ axis: A scalar integer `Tensor` indicating the axis to frame. Defaults to + the last axis. Supports negative values for indexing from the end. + name: An optional name for the operation. Returns: - A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`. + A `Tensor` of frames with shape `[..., frames, frame_length, ...]`. Raises: - ValueError: if signal does not have rank 2. + ValueError: If `frame_length`, `frame_step`, or `pad_value` are not scalar. """ - with ops.name_scope(name, "frames", [signal, frame_length, frame_step]): + with ops.name_scope(name, "frame", [signal, frame_length, frame_step, + pad_value]): signal = ops.convert_to_tensor(signal, name="signal") frame_length = ops.convert_to_tensor(frame_length, name="frame_length") frame_step = ops.convert_to_tensor(frame_step, name="frame_step") + axis = ops.convert_to_tensor(axis, name="axis") - signal_rank = signal.shape.ndims + signal.shape.with_rank_at_least(1) + frame_length.shape.assert_has_rank(0) + frame_step.shape.assert_has_rank(0) + axis.shape.assert_has_rank(0) - if signal_rank != 2: - raise ValueError("expected signal to have rank 2 but was " + signal_rank) + result_shape = _infer_frame_shape(signal, frame_length, frame_step, pad_end, + axis) - signal_length = array_ops.shape(signal)[1] + # Axis can be negative. Convert it to positive. + signal_rank = array_ops.rank(signal) + axis = math_ops.range(signal_rank)[axis] - num_frames = math_ops.ceil((signal_length - frame_length) / frame_step) - num_frames = 1 + math_ops.cast(num_frames, dtypes.int32) + signal_shape = array_ops.shape(signal) + outer_dimensions, length_samples, inner_dimensions = array_ops.split( + signal_shape, [axis, 1, signal_rank - 1 - axis]) + length_samples = array_ops.reshape(length_samples, []) + num_outer_dimensions = array_ops.size(outer_dimensions) + num_inner_dimensions = array_ops.size(inner_dimensions) - pad_length = (num_frames - 1) * frame_step + frame_length - pad_signal = array_ops.pad(signal, [[0, 0], [0, - pad_length - signal_length]]) + # If padding is requested, pad the input signal tensor with pad_value. + if pad_end: + pad_value = ops.convert_to_tensor(pad_value, signal.dtype) + pad_value.shape.assert_has_rank(0) - indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0) - indices_frames = array_ops.tile(indices_frame, [num_frames, 1]) + # Calculate number of frames, using double negatives to round up. + num_frames = -(-length_samples // frame_step) - indices_step = array_ops.expand_dims( - math_ops.range(num_frames) * frame_step, 1) - indices_steps = array_ops.tile(indices_step, [1, frame_length]) + # Pad the signal by up to frame_length samples based on how many samples + # are remaining starting from last_frame_position. + pad_samples = math_ops.maximum( + 0, frame_length + frame_step * (num_frames - 1) - length_samples) - indices = indices_frames + indices_steps + # Pad the inner dimension of signal by pad_samples. 
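+      # For example, for a rank-3 signal framed along axis=1, paddings is
+      # [[0, 0], [0, pad_samples], [0, 0]].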
+ paddings = array_ops.concat( + [array_ops.zeros([num_outer_dimensions, 2], dtype=pad_samples.dtype), + [[0, pad_samples]], + array_ops.zeros([num_inner_dimensions, 2], dtype=pad_samples.dtype)], + 0) + signal = array_ops.pad(signal, paddings, constant_values=pad_value) - # TODO(androbin): remove `transpose` when `gather` gets `axis` support - pad_signal = array_ops.transpose(pad_signal) - signal_frames = array_ops.gather(pad_signal, indices) - signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1]) + signal_shape = array_ops.shape(signal) + length_samples = signal_shape[axis] + else: + num_frames = math_ops.maximum( + 0, 1 + (length_samples - frame_length) // frame_step) - return signal_frames + subframe_length = util_ops.gcd(frame_length, frame_step) + subframes_per_frame = frame_length // subframe_length + subframes_per_hop = frame_step // subframe_length + num_subframes = length_samples // subframe_length + + slice_shape = array_ops.concat([outer_dimensions, + [num_subframes * subframe_length], + inner_dimensions], 0) + subframe_shape = array_ops.concat([outer_dimensions, + [num_subframes, subframe_length], + inner_dimensions], 0) + subframes = array_ops.reshape(array_ops.strided_slice( + signal, array_ops.zeros_like(signal_shape), + slice_shape), subframe_shape) + + # frame_selector is a [num_frames, subframes_per_frame] tensor + # that indexes into the appropriate frame in subframes. For example: + # [[0, 0, 0, 0], [2, 2, 2, 2], [4, 4, 4, 4]] + frame_selector = array_ops.reshape( + math_ops.range(num_frames) * subframes_per_hop, [num_frames, 1]) + + # subframe_selector is a [num_frames, subframes_per_frame] tensor + # that indexes into the appropriate subframe within a frame. For example: + # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]] + subframe_selector = array_ops.reshape( + math_ops.range(subframes_per_frame), [1, subframes_per_frame]) + + # Adding the 2 selector tensors together produces a [num_frames, + # subframes_per_frame] tensor of indices to use with tf.gather to select + # subframes from subframes. We then reshape the inner-most + # subframes_per_frame dimension to stitch the subframes together into + # frames. For example: [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]. + selector = frame_selector + subframe_selector + + frames = array_ops.reshape( + array_ops.gather(subframes, selector, axis=axis), + array_ops.concat([outer_dimensions, [num_frames, frame_length], + inner_dimensions], 0)) + + if result_shape: + frames.set_shape(result_shape) + return frames diff --git a/tensorflow/contrib/signal/python/ops/util_ops.py b/tensorflow/contrib/signal/python/ops/util_ops.py new file mode 100644 index 00000000000..eee829d799e --- /dev/null +++ b/tensorflow/contrib/signal/python/ops/util_ops.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Utility ops shared across tf.contrib.signal."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+def gcd(a, b, name=None):
+  """Returns the greatest common divisor via Euclid's algorithm.
+
+  Args:
+    a: The dividend. A scalar integer `Tensor`.
+    b: The divisor. A scalar integer `Tensor`.
+    name: An optional name for the operation.
+
+  Returns:
+    A scalar `Tensor` representing the greatest common divisor between `a` and
+    `b`.
+
+  Raises:
+    ValueError: If `a` or `b` are not scalar integers.
+  """
+  with ops.name_scope(name, 'gcd', [a, b]):
+    a = ops.convert_to_tensor(a)
+    b = ops.convert_to_tensor(b)
+
+    a.shape.assert_has_rank(0)
+    b.shape.assert_has_rank(0)
+
+    if not a.dtype.is_integer:
+      raise ValueError('a must be an integer type. Got: %s' % a.dtype)
+    if not b.dtype.is_integer:
+      raise ValueError('b must be an integer type. Got: %s' % b.dtype)
+
+    cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b))
+    body = lambda a, b: [b, math_ops.mod(a, b)]
+    a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False)
+    return a
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
index 81b4534f10e..63bfc1aef18 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
@@ -207,7 +207,8 @@ void ClassificationStats::AddExample(
 }
 
 void ClassificationStats::CheckPrune() {
-  if (IsFinished() || weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
+  if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
+      weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
     return;
   }
   ++prune_sample_epoch_;
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index e8a5b675469..dae58720edb 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1514,6 +1514,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
   params.input_device_contexts = &input_device_contexts;
   params.input_alloc_attrs = &input_alloc_attrs;
   params.runner = &runner_;
+  params.stats_collector = stats_collector_;
 
   Status s;
   NodeExecStats* stats = nullptr;
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index e3cc97c9461..94b9d33c0cf 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -261,6 +261,7 @@ class CallOp : public AsyncOpKernel {
     FunctionLibraryRuntime::Options opts;
     opts.step_id = ctx->step_id();
     opts.step_container = ctx->step_container();
+    opts.stats_collector = ctx->stats_collector();
     opts.runner = ctx->runner();
     std::vector<Tensor> args;
     args.reserve(ctx->num_inputs());
@@ -545,6 +546,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
   // Inherit the step_id from the caller.
   exec_args.step_id = opts.step_id;
   exec_args.step_container = opts.step_container;
+
+  exec_args.stats_collector = opts.stats_collector;
   exec_args.call_frame = frame;
   exec_args.cancellation_manager = opts.cancellation_manager;
   exec_args.runner = *opts.runner;
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 059f2be629b..aeb924a709c 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -38,6 +38,7 @@ class GraphDef;
 class OpKernel;
 class ResourceMgr;
 class ScopedStepContainer;
+class StepStatsCollector;
 class Node;
 
 // FunctionDefHelper::Create is a convenient helper to construct a
@@ -402,7 +403,8 @@ class FunctionLibraryRuntime {
     int64 step_id = 0;
 
     // Per-step container.
-    ScopedStepContainer* step_container;
+    ScopedStepContainer* step_container = nullptr;
+    StepStatsCollector* stats_collector = nullptr;
 
     std::function<void(std::function<void()>)>* runner = nullptr;
   };
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index bf52c53b886..f8f61df872e 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -72,6 +72,7 @@ class OpKernelContext;  // declared below
 class OpRegistryInterface;
 class ResourceMgr;
 class ScopedStepContainer;
+class StepStatsCollector;
 
 class OpKernel {
  public:
@@ -551,6 +552,7 @@ class OpKernelContext {
     FunctionCallFrame* call_frame = nullptr;
     FunctionLibraryRuntime* function_library = nullptr;
     std::function<void(std::function<void()>)>* runner = nullptr;
+    StepStatsCollector* stats_collector = nullptr;
 
     // TensorSliceReaderCache support.
     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -940,6 +942,9 @@ class OpKernelContext {
   std::function<void(std::function<void()>)>* runner() const {
     return params_->runner;
   }
+  StepStatsCollector* stats_collector() const {
+    return params_->stats_collector;
+  }
 
   // Shared resources accessible to this kernel.
ResourceMgr* resource_manager() const { return params_->resource_manager; } diff --git a/tensorflow/core/kernels/concat_lib_gpu.cc b/tensorflow/core/kernels/concat_lib_gpu.cc index cd0414ef409..5159cdaa6ec 100644 --- a/tensorflow/core/kernels/concat_lib_gpu.cc +++ b/tensorflow/core/kernels/concat_lib_gpu.cc @@ -115,6 +115,7 @@ void ConcatGPU( TF_CALL_GPU_NUMBER_TYPES(REGISTER); TF_CALL_complex64(REGISTER); TF_CALL_complex128(REGISTER); +TF_CALL_int64(REGISTER); REGISTER(bfloat16); #undef REGISTER diff --git a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc index d120c69c4f4..f971637d5db 100644 --- a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc +++ b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc @@ -201,21 +201,25 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device, TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32); TF_CALL_complex64(REGISTER_GPUCONCAT32); TF_CALL_complex128(REGISTER_GPUCONCAT32); +TF_CALL_int64(REGISTER_GPUCONCAT32); REGISTER_GPUCONCAT32(bfloat16); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64); TF_CALL_complex64(REGISTER_GPUCONCAT64); TF_CALL_complex128(REGISTER_GPUCONCAT64); +TF_CALL_int64(REGISTER_GPUCONCAT64); REGISTER_GPUCONCAT64(bfloat16); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32); TF_CALL_complex64(REGISTER_GPU32); TF_CALL_complex128(REGISTER_GPU32); +TF_CALL_int64(REGISTER_GPU32); REGISTER_GPU32(bfloat16); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64); TF_CALL_complex64(REGISTER_GPU64); TF_CALL_complex128(REGISTER_GPU64); +TF_CALL_int64(REGISTER_GPU64); REGISTER_GPU64(bfloat16); #undef REGISTER_GPUCONCAT32 diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index e7848a7e260..01a744dc7ec 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ -195,6 +195,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); REGISTER_GPU(bfloat16); TF_CALL_complex64(REGISTER_GPU); TF_CALL_complex128(REGISTER_GPU); +TF_CALL_int64(REGISTER_GPU); #undef REGISTER_GPU // A special GPU kernel for int32. 
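The `TF_CALL_int64` registrations above give `tf.concat` a GPU kernel for
`int64` inputs (and `pack_op.cc` below does the same for `tf.stack`). A
minimal sketch of what this enables, assuming a CUDA-enabled build (the
tensor values are illustrative):

```py
import tensorflow as tf

with tf.device("/gpu:0"):
  a = tf.constant([[1, 2]], dtype=tf.int64)
  b = tf.constant([[3, 4]], dtype=tf.int64)
  c = tf.concat([a, b], 0)  # Concat: now has an int64 GPU kernel.
  s = tf.stack([a, b])      # Pack: registered for int64 below.

# log_device_placement shows these ops running on the GPU instead of
# falling back to the CPU.
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
  print(sess.run([c, s]))
```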
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index e0cc08f101c..2db844e410c 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -216,7 +216,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
         TensorShapeUtils::IsScalar(context_dense_keys[di].shape()),
         errors::InvalidArgument(
             "Expected context_dense_keys[", di,
-            "] to be a vector, got shape: ",
+            "] to be a scalar, got shape: ",
             context_dense_keys[di].shape().DebugString()));
     context_dense_keys_t[di] = context_dense_keys[di].scalar<string>()();
   }
@@ -225,7 +225,7 @@
         TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()),
         errors::InvalidArgument(
             "Expected context_sparse_keys[", di,
-            "] to be a vector, got shape: ",
+            "] to be a scalar, got shape: ",
             context_sparse_keys[di].shape().DebugString()));
     context_sparse_keys_t[di] = context_sparse_keys[di].scalar<string>()();
   }
@@ -234,7 +234,7 @@
         ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()),
         errors::InvalidArgument(
             "Expected feature_list_dense_keys[", di,
-            "] to be a vector, got shape: ",
+            "] to be a scalar, got shape: ",
             feature_list_dense_keys[di].shape().DebugString()));
     feature_list_dense_keys_t[di] =
         feature_list_dense_keys[di].scalar<string>()();
   }
@@ -244,7 +244,7 @@
         ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()),
         errors::InvalidArgument(
             "Expected feature_list_sparse_keys[", di,
-            "] to be a vector, got shape: ",
+            "] to be a scalar, got shape: ",
             feature_list_sparse_keys[di].shape().DebugString()));
     feature_list_sparse_keys_t[di] =
         feature_list_sparse_keys[di].scalar<string>()();
diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc
index edaa10761eb..75820e3106f 100644
--- a/tensorflow/core/kernels/pack_op.cc
+++ b/tensorflow/core/kernels/pack_op.cc
@@ -157,6 +157,7 @@ REGISTER_PACK(string);
                           PackOp<GPUDevice, type>)
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
 
 #undef REGISTER_GPU
 
 // A special GPU kernel for int32.
diff --git a/tensorflow/docs_src/get_started/export.md b/tensorflow/docs_src/get_started/export.md
new file mode 100644
index 00000000000..8273030271e
--- /dev/null
+++ b/tensorflow/docs_src/get_started/export.md
@@ -0,0 +1,298 @@
+# Exporting a Trained Model for Serving
+
+Once you have trained an `Estimator` model, you may want to create a service
+from that model that takes requests and returns a result. You can run such a
+service locally on your machine or deploy it scalably in the cloud.
+
+To prepare a trained Estimator for serving, you must export it in the standard
+[`SavedModel`](https://www.tensorflow.org/code/tensorflow/python/saved_model/README.md)
+format, which wraps the TensorFlow graph, the trained variable values, any
+required assets, and metadata together in a hermetic package.
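+
+At a glance, the export step this tutorial builds toward is a single call
+(a sketch; `serving_input_receiver_fn` is a function you will write below):
+
+```py
+estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)
+```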
+ +In this tutorial, we will discuss how to: + +* Add graph nodes that accept and prepare inference requests +* Specify the output nodes and the corresponding [APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto) + that can be served (Classify, Regress, or Predict) +* Export your model to the `SavedModel` format +* Deploy the model in Google Cloud ML Engine and request predictions +* Serve the model from a local server and request predictions + + +## The exported graph and its signatures + +The export procedure assembles a new TensorFlow graph from two main components: +1) a Serving Input Receiver that defines the format of the inputs to be + accepted, and +2) the trained model itself. + +An exported `SavedModel` contains that combined graph packaged together with one +or more *signatures*. Like a function signature in any programming language, a +graph signature specifies the required inputs (arguments) and the +expected outputs (return values) of performing the computation. In the typical +case, a single signature is present, corresponding to the predictions that the +model has learned to make. + +The *input* portion of the signature is determined by the Serving Input +Receiver. To specify the inputs that your deployed model will accept, you must +provide a `serving_input_receiver_fn()` to `estimator.export_savedmodel()` (see +below). + +The *output* portion of the signature is determined by the model. For instance, +canned Estimators know the nature of the outputs they produce (e.g. whether the +output is a classification or a regression, and the type and shape of those +outputs). Custom Estimators must provide this information via `export_outputs` +(see [below](#specifying_the_outputs_of_a_custom_model)). + +> Note: A *multi-headed model* provides multiple signatures, each corresponding +> to a different "head", i.e. a set of predictions that can be made from the +> same inputs by executing a subgraph of the complete trained graph. The +> *output* portions of these signatures are determined by the model. + + +![Overview of exporting a SavedModel from Estimator](../images/export_savedmodel_overview.png) + +## Preparing serving inputs + +During training, an @{$input_fn$`input_fn()`} ingests data and prepares it for +use by the model. At serving time, similarly, a `serving_input_receiver_fn()` +accepts inference requests and prepares them for the model. The purpose of this +function is to add placeholders to the graph which the serving system will feed +with inference requests, as well as to add any additional ops needed to convert +data from the input format into the feature `Tensor`s expected by the model. +The function returns a @{tf.estimator.export.ServingInputReceiver} object, which +packages the placeholders and the resulting feature `Tensor`s together. + +A typical pattern is that inference requests arrive in the form of serialized +`tf.Example`s, so the `serving_input_receiver_fn()` creates a single string +placeholder to receive them. The `serving_input_receiver_fn()` is then also +responsible for parsing the `tf.Example`s by adding a @{tf.parse_example} op to +the graph. + +When writing such a `serving_input_receiver_fn()`, you must pass a parsing +specification to @{tf.parse_example} to tell the parser what feature names to +expect and how to map them to `Tensor`s. A parsing specification takes the +form of a dict from feature names to @{tf.FixedLenFeature}, @{tf.VarLenFeature}, +and @{tf.SparseFeature}. 
+(Note this parsing specification should not include any label or weight
+columns, since those will not be available at serving time—in contrast to a
+parsing specification used in the `input_fn()` at training time.)
+
+In combination, then:
+
+```py
+feature_spec = {'foo': tf.FixedLenFeature(...),
+                'bar': tf.VarLenFeature(...)}
+
+# A batch dimension of None lets the receiver accept batches of any size.
+default_batch_size = None
+
+def serving_input_receiver_fn():
+  """An input receiver that expects a serialized tf.Example."""
+  serialized_tf_example = tf.placeholder(dtype=tf.string,
+                                         shape=[default_batch_size],
+                                         name='input_example_tensor')
+  receiver_tensors = {'examples': serialized_tf_example}
+  features = tf.parse_example(serialized_tf_example, feature_spec)
+  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+```
+
+The @{tf.estimator.export.build_parsing_serving_input_receiver_fn} utility
+function provides that input receiver for the common case.
+
+> Note: when training a model to be served using Google Cloud ML Engine (see
+> below), the parsing step is not needed, because the model will receive raw
+> feature data. This is also true when using the Predict API with a local
+> server.
+
+Even if you require no parsing or other input processing—i.e., if the
+serving system will feed feature `Tensor`s directly—you must still provide
+a `serving_input_receiver_fn()` that creates placeholders for the feature
+`Tensor`s and passes them through. The
+@{tf.estimator.export.build_raw_serving_input_receiver_fn} utility provides for
+this.
+
+If these utilities do not meet your needs, you are free to write your own
+`serving_input_receiver_fn()`. One case where this may be needed is if your
+training `input_fn()` incorporates some preprocessing logic that must be
+recapitulated at serving time. To reduce the risk of training-serving skew, we
+recommend encapsulating such processing in a function which is then called
+from both `input_fn()` and `serving_input_receiver_fn()`.
+
+
+## Performing the export
+
+To export your trained Estimator, call
+@{tf.estimator.Estimator.export_savedmodel} with the export base path, together
+with the `serving_input_receiver_fn`.
+
+```py
+estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)
+```
+
+This method builds a new graph by first calling the
+`serving_input_receiver_fn()` to obtain feature `Tensor`s, and then calling
+this `Estimator`'s `model_fn()` to generate the model graph based on those
+features. It starts a fresh `Session`, and, by default, restores the most
+recent checkpoint into it. (A different checkpoint may be passed, if needed.)
+Finally it creates a timestamped export directory below the given
+`export_dir_base` (i.e., `export_dir_base/<timestamp>`), and writes a
+`SavedModel` into it containing a single `MetaGraphDef` saved from this
+`Session`.
+
+> Note: there is currently no built-in mechanism to garbage-collect old
+> exports, so successive exports will accumulate under `export_dir_base`
+> unless deleted by some external means.
+
+## Specifying the outputs of a custom model
+
+When writing a custom `model_fn`, you must populate the `export_outputs`
+element of the @{tf.estimator.EstimatorSpec} return value. This is a dict of
+`{name: output}` describing the output signatures to be exported and used
+during serving.
+
+In the usual case of making a single prediction, this dict contains one
+element, and the `name` is immaterial. In a multi-headed model, each head is
+represented by an entry in this dict, as in the hypothetical sketch below.
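+
+For illustration only, a two-headed `export_outputs` might be assembled like
+this inside a custom `model_fn` (the head name, the `class_names` and `prices`
+tensors, and the rest of the model body are assumptions for this sketch):
+
+```py
+def my_model_fn(features, labels, mode):
+  # ... layers producing `class_names` (a string Tensor) and `prices`
+  # (a float Tensor) would go here ...
+
+  export_outputs = {
+      # The default head, used when a request does not name a signature.
+      # For the Classify API, `classes` must be a string Tensor.
+      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+          tf.estimator.export.ClassificationOutput(classes=class_names),
+      # A second head, served via the Regress API when requested by name.
+      'predict_price': tf.estimator.export.RegressionOutput(value=prices),
+  }
+  return tf.estimator.EstimatorSpec(
+      mode=mode, predictions={'prices': prices}, export_outputs=export_outputs)
+```
+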
+Each `name` is a string of your choice that can be used to request a specific
+head at serving time.
+
+Each `output` value must be an `ExportOutput` object such as
+@{tf.estimator.export.ClassificationOutput},
+@{tf.estimator.export.RegressionOutput}, or
+@{tf.estimator.export.PredictOutput}.
+
+These output types map straightforwardly to the
+[TensorFlow Serving APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto),
+and so determine which request types will be honored.
+
+> Note: In the multi-headed case, a `SignatureDef` will be generated for each
+> element of the `export_outputs` dict returned from the `model_fn`, named
+> using the same keys. These signatures differ only in their outputs, as
+> provided by the corresponding `ExportOutput` entry. The inputs are always
+> those provided by the `serving_input_receiver_fn`.
+> An inference request may specify the head by name. Exactly one head must be
+> named using [`signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`](https://www.tensorflow.org/code/tensorflow/python/saved_model/signature_constants.py);
+> that signature is served when an inference request does not specify a head.
+
+
+## Serving the exported model on Google Cloud ML Engine
+
+[Google Cloud ML Engine](https://cloud.google.com/ml-engine/) provides a fully
+managed, scalable environment for serving your trained SavedModels to make
+online or batch predictions.
+
+Please see [Deploying Models](https://cloud.google.com/ml-engine/docs/how-tos/deploying-models)
+to learn how to deploy your SavedModel on Cloud ML Engine.
+
+> Note: Cloud ML Engine accepts inference requests in JSON, CSV, or TFRecords
+> formats, depending on the circumstance. Parsing these formats is not the
+> responsibility of the graph. Cloud ML Engine does the parsing for you, and
+> feeds raw feature data directly into the graph. Thus, when targeting Cloud ML
+> Engine, you should use a `serving_input_receiver_fn()` of the passthrough form
+> that simply creates placeholders for each feature.
+
+
+## Requesting predictions from Google Cloud ML Engine
+
+To learn how to request predictions from a model deployed in Cloud ML Engine,
+please see:
+
+* [Prediction Basics](https://cloud.google.com/ml-engine/docs/concepts/prediction-overview)
+* [Getting Online Predictions](https://cloud.google.com/ml-engine/docs/how-tos/online-predict)
+* [Getting Batch Predictions](https://cloud.google.com/ml-engine/docs/how-tos/batch-predict)
+
+
+## Serving the exported model locally
+
+For local deployment, you can serve your model using
+@{$deploy/tfserve$TensorFlow Serving}, an open-source project that loads a
+`SavedModel` and exposes it as a [gRPC](http://www.grpc.io/) service.
+
+First, [install TensorFlow Serving](https://tensorflow.github.io/serving/setup#prerequisites).
+
+Then build and run the local model server, replacing `$export_dir_base` with
+the path to the `SavedModel` you exported above:
+
+```sh
+bazel build //tensorflow_serving/model_servers:tensorflow_model_server
+bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_base_path=$export_dir_base
+```
+
+Now you have a server listening for inference requests via gRPC on port 9000!
+
+
+## Requesting predictions from a local server
+
+The server responds to gRPC requests according to the [PredictionService](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto#L15)
+gRPC API service definition.
+(The nested protocol buffers are defined in various
+[neighboring files](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis).)
+
+From the API service definition, the gRPC framework generates client libraries
+in various languages providing remote access to the API. In a project using the
+Bazel build tool, these libraries are built automatically and provided via
+dependencies like these (using Python for example):
+
+```build
+  deps = [
+    "//tensorflow_serving/apis:classification_proto_py_pb2",
+    "//tensorflow_serving/apis:regression_proto_py_pb2",
+    "//tensorflow_serving/apis:predict_proto_py_pb2",
+    "//tensorflow_serving/apis:prediction_service_proto_py_pb2"
+  ]
+```
+
+Python client code can then import the libraries as follows:
+
+```py
+from tensorflow_serving.apis import classification_pb2
+from tensorflow_serving.apis import regression_pb2
+from tensorflow_serving.apis import predict_pb2
+from tensorflow_serving.apis import prediction_service_pb2
+```
+
+> Note: `prediction_service_pb2` defines the service as a whole and so is
+> always required. However, a typical client will need only one of
+> `classification_pb2`, `regression_pb2`, and `predict_pb2`, depending on the
+> type of requests being made.
+
+Sending a gRPC request is then accomplished by assembling a protocol buffer
+containing the request data and passing it to the service stub. Note how the
+request protocol buffer is created empty and then populated via the
+[generated protocol buffer API](https://developers.google.com/protocol-buffers/docs/reference/python-generated).
+
+```py
+from grpc.beta import implementations
+
+channel = implementations.insecure_channel(host, int(port))
+stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+
+# Build a ClassificationRequest containing a single tf.Example.
+# (`image` is assumed to be a numpy array defined elsewhere.)
+request = classification_pb2.ClassificationRequest()
+example = request.input.example_list.examples.add()
+example.features.feature['x'].float_list.value.extend(image[0].astype(float))
+
+result = stub.Classify(request, 10.0)  # 10 secs timeout
+```
+
+The returned result in this example is a `ClassificationResponse` protocol
+buffer.
+
+This is a skeletal example; please see the @{$deploy$TensorFlow Serving}
+documentation and [examples](https://github.com/tensorflow/serving/tree/master/tensorflow_serving/example)
+for more details.
+
+> Note: `ClassificationRequest` and `RegressionRequest` contain a
+> `tensorflow.serving.Input` protocol buffer, which in turn contains a list of
+> `tensorflow.Example` protocol buffers. `PredictRequest`, by contrast,
+> contains a mapping from feature names to values encoded via `TensorProto`.
+> Correspondingly: When using the `Classify` and `Regress` APIs, TensorFlow
+> Serving feeds serialized `tf.Example`s to the graph, so your
+> `serving_input_receiver_fn()` should include a `tf.parse_example()` Op.
+> When using the generic `Predict` API, however, TensorFlow Serving feeds raw
+> feature data to the graph, so a passthrough `serving_input_receiver_fn()`
+> should be used.
diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md
index 0d302ec3830..dd69408cba4 100644
--- a/tensorflow/docs_src/get_started/index.md
+++ b/tensorflow/docs_src/get_started/index.md
@@ -26,6 +26,8 @@ To learn about the high-level API, read the following guides:
     which takes you into a somewhat more sophisticated use of this API.
   * @{$get_started/monitors$Logging and Monitoring Basics with tf.contrib.learn},
     which explains how to audit the progress of model training.
+  * @{$get_started/export$Exporting a Trained Model for Serving}, which shows
+    how to save a trained model in a form that is ready to deploy.
 
 TensorBoard is a utility to visualize different aspects of machine learning.
 The following guides explain how to use TensorBoard:
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index 0bb5b551555..aba4224dc62 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -141,6 +141,7 @@ class ConcatOpTest(test.TestCase):
     self._testRandom(dtypes.float32)
     self._testRandom(dtypes.int16)
     self._testRandom(dtypes.int32)
+    self._testRandom(dtypes.int64)
     self._testRandom(dtypes.bfloat16)
     self._testRandom(dtypes.complex64)
     self._testRandom(dtypes.complex128)
diff --git a/tensorflow/python/kernel_tests/stack_op_test.py b/tensorflow/python/kernel_tests/stack_op_test.py
index afc0c38cacb..95ea3a90473 100644
--- a/tensorflow/python/kernel_tests/stack_op_test.py
+++ b/tensorflow/python/kernel_tests/stack_op_test.py
@@ -45,45 +45,61 @@ class StackOpTest(test.TestCase):
     np.random.seed(7)
     with self.test_session(use_gpu=True):
       for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
-        data = np.random.randn(*shape)
-        # Convert [data[0], data[1], ...] separately to tensorflow
-        # TODO(irving): Remove list() once we handle maps correctly
-        xs = list(map(constant_op.constant, data))
-        # Pack back into a single tensorflow tensor
-        c = array_ops.stack(xs)
-        self.assertAllEqual(c.eval(), data)
+        for dtype in [np.float32, np.int32, np.int64]:
+          data = np.random.randn(*shape).astype(dtype)
+          # Convert [data[0], data[1], ...] separately to tensorflow
+          # TODO(irving): Remove list() once we handle maps correctly
+          xs = list(map(constant_op.constant, data))
+          # Pack back into a single tensorflow tensor
+          c = array_ops.stack(xs)
+          self.assertAllEqual(c.eval(), data)
+
+  def testSimpleParallel(self):
+    np.random.seed(7)
+    with self.test_session(use_gpu=True):
+      for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
+        data = np.random.randn(*shape).astype(np.float32)
+        xs = list(map(constant_op.constant, data))
         c = array_ops.parallel_stack(xs)
         self.assertAllEqual(c.eval(), data)
 
   def testConst(self):
+    np.random.seed(7)
+    with self.test_session(use_gpu=True):
+      for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
+        for dtype in [np.float32, np.int32, np.int64]:
+          data = np.random.randn(*shape).astype(dtype)
+          # Pack back into a single tensorflow tensor directly using np array
+          c = array_ops.stack(data)
+          # This is implemented via a Const:
+          self.assertEqual(c.op.type, "Const")
+          self.assertAllEqual(c.eval(), data)
+
+          # Python lists also work for 1-D case:
+          if len(shape) == 1:
+            data_list = list(data)
+            cl = array_ops.stack(data_list)
+            self.assertEqual(cl.op.type, "Const")
+            self.assertAllEqual(cl.eval(), data)
+
+    # Verify that shape induction works with shapes produced via const stack
+    a = constant_op.constant([1, 2, 3, 4, 5, 6])
+    b = array_ops.reshape(a, array_ops.stack([2, 3]))
+    self.assertAllEqual(b.get_shape(), [2, 3])
+
+  def testConstParallel(self):
     np.random.seed(7)
     with self.test_session(use_gpu=True):
       for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
         data = np.random.randn(*shape).astype(np.float32)
-        # Pack back into a single tensorflow tensor directly using np array
-        c = array_ops.stack(data)
-        # This is implemented via a Const:
-        self.assertEqual(c.op.type, "Const")
-        self.assertAllEqual(c.eval(), data)
-
-        c = array_ops.parallel_stack(data)
-        self.assertAllEqual(c.eval(), data)
-
-        # Python lists also work for 1-D case:
         if len(shape) == 1:
           data_list = list(data)
-          cl = array_ops.stack(data_list)
-          self.assertEqual(cl.op.type, "Const")
-          self.assertAllEqual(cl.eval(), data)
-
           cl = array_ops.parallel_stack(data_list)
           self.assertAllEqual(cl.eval(), data)
 
-    # Verify that shape induction works with shapes produced via const stack
-    a = constant_op.constant([1, 2, 3, 4, 5, 6])
-    b = array_ops.reshape(a, array_ops.stack([2, 3]))
-    self.assertAllEqual(b.get_shape(), [2, 3])
+        data = np.random.randn(*shape).astype(np.float32)
+        c = array_ops.parallel_stack(data)
+        self.assertAllEqual(c.eval(), data)
 
   def testGradientsAxis0(self):
     np.random.seed(7)
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index 8e364f7ffb7..87793dabc5a 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -69,10 +69,13 @@ ln -s $(pwd)/tensorflow ${PIP_TEST_ROOT}/tensorflow
 
 # Do not run tests with "no_pip" tag. If running GPU tests, also do not run
 # tests with no_pip_gpu tag.
-PIP_TEST_FILTER_TAG="-no_pip"
+PIP_TEST_FILTER_TAG="-no_oss,-no_pip"
 if [[ ${IS_GPU} == "1" ]]; then
   PIP_TEST_FILTER_TAG="-no_pip_gpu,${PIP_TEST_FILTER_TAG}"
 fi
+if [[ ${IS_MAC} == "1" ]]; then
+  PIP_TEST_FILTER_TAG="-nomac,${PIP_TEST_FILTER_TAG}"
+fi
 
 # Bazel flags we need for all tests:
 # define=no_tensorflow_py_deps=true, to skip all test dependencies.