Merge pull request #11508 from av8ramit/branch_162017464
Branch 162017464
This commit is contained in:
commit b126091bbb
@@ -315,6 +315,7 @@ class LayoutAssignment : public HloPassInterface {

  ComputationLayout* entry_computation_layout_;

 protected:
  // Map containing the layouts of all computations assigned so
  // far. Computations are handled in a topological sort where computations are
  // handled before their caller instructions so the layouts of caller
@@ -32,92 +32,7 @@ limitations under the License.
namespace xla {
namespace {

-class SliceTest : public ClientLibraryTestBase {
- protected:
-  template <typename NativeT>
-  void RunSliceTenToTwo() {
-    std::vector<NativeT> constant;
-    constant.reserve(10);
-    for (int i = 0; i < 10; ++i) {
-      constant.push_back(static_cast<NativeT>(i));
-    }
-
-    ComputationBuilder builder(client_, TestName());
-    auto original = builder.ConstantR1<NativeT>(constant);
-    builder.Slice(original, {2}, {4}, {1});
-
-    const std::vector<NativeT> expected = {static_cast<NativeT>(2),
-                                           static_cast<NativeT>(3)};
-    ComputeAndCompareR1<NativeT>(&builder, expected, {});
-  }
-};
-
-XLA_TEST_F(SliceTest, SliceZeroToZeroF32) {
-  ComputationBuilder builder(client_, TestName());
-  auto original = builder.ConstantR1<float>({});
-  builder.Slice(original, {0}, {0}, {1});
-
-  ComputeAndCompareR1<float>(&builder, {}, {});
-}
-
-XLA_TEST_F(SliceTest, SliceTenToZeroF32) {
-  ComputationBuilder builder(client_, TestName());
-  std::vector<float> constant(10, 0.3);
-  auto original = builder.ConstantR1<float>(constant);
-  builder.Slice(original, {7}, {7}, {1});
-
-  ComputeAndCompareR1<float>(&builder, {}, {});
-}
-
-TEST_F(SliceTest, SliceTenToTwoF32) { RunSliceTenToTwo<float>(); }
-
-XLA_TEST_F(SliceTest, SliceTenToTwoF64) { RunSliceTenToTwo<double>(); }
-
-TEST_F(SliceTest, SliceTenToTwoU32) { RunSliceTenToTwo<uint32>(); }
-
-TEST_F(SliceTest, SliceTenToTwoS32) { RunSliceTenToTwo<int32>(); }
-
-XLA_TEST_F(SliceTest, SliceTenToTwoU64) { RunSliceTenToTwo<uint64>(); }
-
-XLA_TEST_F(SliceTest, SliceTenToTwoS64) { RunSliceTenToTwo<int64>(); }
-
-TEST_F(SliceTest, SliceTenToTen) {
-  const std::vector<float> values = {0.0, 1.0, 2.0, 3.0, 4.0,
-                                     5.0, 6.0, 7.0, 8.0, 9.0};
-
-  ComputationBuilder builder(client_, TestName());
-  auto original = builder.ConstantR1<float>(values);
-  builder.Slice(original, {0}, {10}, {1});
-
-  ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001));
-}
-
-TEST_F(SliceTest, SliceLastFourOf1024) {
-  std::vector<float> values(1024);
-  std::iota(values.begin(), values.end(), 0.0);
-
-  ComputationBuilder builder(client_, TestName());
-  auto original = builder.ConstantR1<float>(values);
-  builder.Slice(original, {1024 - 4}, {1024}, {1});
-
-  const std::vector<float> expected = {1020, 1021, 1022, 1023};
-  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
-}
-
-// TODO(b/28491443): Fix wrong result on CPU and GPU. Failed on
-// 2016-05-01. Also b/28508652
-TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
-  std::vector<float> values(4096);
-  std::iota(values.begin(), values.end(), 0.0);
-
-  ComputationBuilder builder(client_, TestName());
-  auto original = builder.ConstantR1<float>(values);
-  builder.Slice(original, {7}, {7 + 1024}, {1});
-
-  std::vector<float> expected(1024);
-  std::iota(values.begin(), values.end(), 7.0);
-  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
-}
+class SliceTest : public ClientLibraryTestBase {};

XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
  ComputationBuilder builder(client_, TestName());
@@ -208,6 +123,70 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
  ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
}

struct R1Spec {
  int64 input_dim0;
  int64 slice_start;
  int64 slice_limit;
  int64 slice_stride;
};

// Parameterized test that generates R1 values, slices them according
// to the R1Spec, and compares the result with a computed version.
class SliceR1Test : public ClientLibraryTestBase,
                    public ::testing::WithParamInterface<R1Spec> {
 protected:
  template <typename NativeT>
  void Run(const R1Spec& spec) {
    std::vector<NativeT> input(spec.input_dim0);
    std::iota(input.begin(), input.end(), NativeT());

    ComputationBuilder builder(client_, TestName());
    auto original = builder.ConstantR1<NativeT>(input);
    builder.Slice(original, {spec.slice_start}, {spec.slice_limit},
                  {spec.slice_stride});

    std::vector<NativeT> expected;
    for (int i = spec.slice_start; i < spec.slice_limit;
         i += spec.slice_stride) {
      expected.push_back(i);
    }

    ComputeAndCompareR1<NativeT>(&builder, expected, {});
  }
};

XLA_TEST_P(SliceR1Test, DoIt) {
  Run<float>(GetParam());
  Run<double>(GetParam());
  Run<uint32>(GetParam());
  Run<int32>(GetParam());
  Run<uint64>(GetParam());
  Run<int64>(GetParam());
}

INSTANTIATE_TEST_CASE_P(  //
    SliceR1TestInstantiation,  //
    SliceR1Test,  //
    ::testing::Values(  //
        R1Spec{10, 0, 0, 1},  //
        R1Spec{10, 7, 7, 1},  //
        R1Spec{10, 2, 4, 1},  //
        R1Spec{10, 2, 4, 1},  //
        R1Spec{10, 2, 4, 1},  //
        R1Spec{10, 2, 4, 1},  //
        R1Spec{10, 2, 4, 1},  //
        R1Spec{10, 2, 4, 1},  //
        R1Spec{10, 0, 10, 1},  //
        R1Spec{1024, 1024 - 4, 1024, 1},  //
        R1Spec{4096, 7, 7 + 1024, 1},  //
        R1Spec{10, 0, 10, 2},  //
        R1Spec{10, 0, 10, 3},  //
        R1Spec{10, 0, 10, 4},  //
        R1Spec{10, 0, 10, 5},  //
        R1Spec{10, 0, 10, 10}  //
    )  //
);
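The `{start}, {limit}, {stride}` triples above follow the same half-open, strided semantics as Python slicing, so each R1Spec's expectation can be cross-checked in NumPy (an illustrative sketch, not part of the change):

import numpy as np

def expected_r1(input_dim0, start, limit, stride):
    # Elements at start, start+stride, ... strictly below limit.
    return np.arange(input_dim0)[start:limit:stride]

assert expected_r1(10, 2, 4, 1).tolist() == [2, 3]         # R1Spec{10, 2, 4, 1}
assert expected_r1(10, 0, 10, 3).tolist() == [0, 3, 6, 9]  # R1Spec{10, 0, 10, 3}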
struct R2Spec {
  int64 input_dim0;
  int64 input_dim1;
@@ -222,13 +201,13 @@ struct R2Spec {
class SliceR2Test : public ClientLibraryTestBase,
                    public ::testing::WithParamInterface<R2Spec> {};

-TEST_P(SliceR2Test, DoIt) {
+XLA_TEST_P(SliceR2Test, DoIt) {
  const R2Spec& spec = GetParam();
  Array2D<int32> input(spec.input_dim0, spec.input_dim1);
  input.FillUnique();

  ComputationBuilder builder(client_, TestName());
-  auto a = builder.ConstantR2FromArray2D<int32>(input);
+  auto a = builder.ConstantR2FromArray2DWithLayout<int32>(input, spec.layout);
  builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);

  std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
@@ -257,6 +236,18 @@ INSTANTIATE_TEST_CASE_P(
    R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}},
            LayoutUtil::MakeLayout({1, 0})},
    R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}},
            LayoutUtil::MakeLayout({1, 0})},
    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}},
            LayoutUtil::MakeLayout({0, 1})},
    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}},
            LayoutUtil::MakeLayout({1, 0})},
    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}},
            LayoutUtil::MakeLayout({0, 1})},
    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}},
            LayoutUtil::MakeLayout({1, 0})},
    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}},
            LayoutUtil::MakeLayout({0, 1})},
    R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}},
            LayoutUtil::MakeLayout({1, 0})}
    )
);
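Likewise for R2: the starts/limits/strides apply per dimension, and the layout only changes the physical element order being exercised. A NumPy sketch of the R2Spec semantics (illustrative only; `ReferenceUtil::Slice2D` is the in-tree reference):

import numpy as np

input_2d = np.arange(100).reshape(10, 10)
# R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, ...}: stride 2 over rows, 1 over columns.
sliced = input_2d[0:10:2, 0:10:1]
assert sliced.shape == (5, 10)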
@@ -515,7 +515,10 @@ py_test(
    size = "small",
    srcs = ["python/keras/utils/data_utils_test.py"],
    srcs_version = "PY2AND3",
-    tags = ["notsan"],
+    tags = [
+        "noasan",  # times out
+        "notsan",
+    ],
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
@@ -587,6 +587,7 @@ py_test(
    name = "head_test",
    size = "medium",
    srcs = ["python/learn/estimators/head_test.py"],
    shard_count = 4,
    srcs_version = "PY2AND3",
    tags = ["noasan"],  # times out b/63678675
    deps = [
@@ -105,7 +105,6 @@ def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
  return resampled, true_expected_count, resampled_expected_count


# TODO(ccolby): Before checkin, Add reference to TAPAS paper when in arxiv.org.
def rank_sampled_softmax_loss(weights,
                              biases,
                              labels,
@@ -21,14 +21,31 @@ py_library(
    ],
)

cuda_py_tests(
    name = "reconstruction_ops_test",
    srcs = ["python/kernel_tests/reconstruction_ops_test.py"],
    additional_deps = [
        ":signal_py",
        "//third_party/py/numpy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

cuda_py_tests(
    name = "shape_ops_test",
    size = "small",
    srcs = ["python/kernel_tests/shape_ops_test.py"],
    additional_deps = [
        ":signal_py",
        "//third_party/py/numpy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
@@ -14,16 +14,21 @@
# ==============================================================================
"""##Signal ops.

-@@frames
+@@frame
@@hamming_window
@@hann_window
+@@overlap_and_add
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-from tensorflow.contrib.signal.python.ops.shape_ops import frames
+from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add
+from tensorflow.contrib.signal.python.ops.shape_ops import frame
+# `frame` used to be named `frames`, which is a noun and not a verb.
+# Keep an alias to `frames` for backwards compatibility.
+from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames
from tensorflow.contrib.signal.python.ops.window_ops import hamming_window
from tensorflow.contrib.signal.python.ops.window_ops import hann_window
@@ -0,0 +1,192 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for reconstruction_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.signal.python.ops import reconstruction_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class ReconstructionOpsTest(test.TestCase):

  def __init__(self, *args, **kwargs):
    super(ReconstructionOpsTest, self).__init__(*args, **kwargs)
    self.batch_size = 3
    self.frames = 3
    self.samples = 5

    self.bases = np.array(range(2, 5))
    exponents = np.array(range(self.frames * self.samples))
    powers = np.power(self.bases[:, np.newaxis], exponents[np.newaxis, :])

    self.powers = np.reshape(powers, [self.batch_size, self.frames,
                                      self.samples])
    self.frame_hop = 2

    # Hand computed example using powers of unique numbers: this is easily
    # verified.
    self.expected_string = ["1", "10", "100100", "1001000", "10010010000",
                            "100100000000", "1001000000000", "10000000000000",
                            "100000000000000"]

  def test_all_ones(self):
    signal = constant_op.constant(np.ones((3, 5)), dtype=dtypes.int64)
    reconstruction = reconstruction_ops.overlap_and_add(signal, 2)

    with self.test_session(use_gpu=True) as sess:
      output = sess.run(reconstruction)

    expected_output = np.array([1, 1, 2, 2, 3, 2, 2, 1, 1])

    self.assertAllClose(output, expected_output)

  def test_simple(self):
    def make_input(frame_length, num_frames=3):
      """Generate a tensor of num_frames frames of frame_length."""
      return np.reshape(np.arange(1, num_frames * frame_length + 1),
                        (-1, frame_length))

    # List of (signal, expected_result, frame_hop).
    configurations = [
        # All hop lengths on a frame length of 2.
        (make_input(2), [1, 5, 9, 6], 1),
        (make_input(2), [1, 2, 3, 4, 5, 6], 2),

        # All hop lengths on a frame length of 3.
        (make_input(3), [1, 6, 15, 14, 9], 1),
        (make_input(3), [1, 2, 7, 5, 13, 8, 9], 2),
        (make_input(3), [1, 2, 3, 4, 5, 6, 7, 8, 9], 3),

        # All hop lengths on a frame length of 4.
        (make_input(4), [1, 7, 18, 21, 19, 12], 1),
        (make_input(4), [1, 2, 8, 10, 16, 18, 11, 12], 2),
        (make_input(4), [1, 2, 3, 9, 6, 7, 17, 10, 11, 12], 3),
        (make_input(4), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4),
    ]

    with self.test_session(use_gpu=True):
      for signal, expected, frame_hop in configurations:
        reconstruction = reconstruction_ops.overlap_and_add(
            np.array(signal), frame_hop).eval()
        expected_output = np.array(expected)
        self.assertAllClose(reconstruction, expected_output)

  def test_powers(self):
    signal = constant_op.constant(np.squeeze(self.powers[0, :, :]),
                                  dtype=dtypes.int64)
    reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop)

    with self.test_session(use_gpu=True) as sess:
      output = sess.run(reconstruction)
      string_output = [np.base_repr(x, self.bases[0]) for x in output]

    self.assertEqual(string_output, self.expected_string)

  def test_batch(self):
    signal = constant_op.constant(self.powers, dtype=dtypes.int64)
    reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop)

    with self.test_session(use_gpu=True) as sess:
      output = sess.run(reconstruction)

    accumulator = True
    for i in range(self.batch_size):
      string_output = [np.base_repr(x, self.bases[i]) for x in output[i, :]]
      accumulator = accumulator and (string_output == self.expected_string)

    self.assertTrue(accumulator)

  def test_one_element_batch(self):
    input_matrix = np.squeeze(self.powers[0, :, :])
    input_matrix = input_matrix[np.newaxis, :, :].astype(float)
    signal = constant_op.constant(input_matrix, dtype=dtypes.float32)
    reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop)

    with self.test_session(use_gpu=True) as sess:
      output = sess.run(reconstruction)

    string_output = [np.base_repr(int(x), self.bases[0]) for x in
                     np.squeeze(output)]

    self.assertEqual(output.shape, (1, 9))
    self.assertEqual(string_output, self.expected_string)

  def test_gradient(self):
    configurations = [
        ((1, 128), 1),
        ((5, 35), 17),
        ((10, 128), 128),
        ((2, 10, 128), 127),
        ((2, 2, 10, 128), 126),
        ((2, 2, 2, 10, 128), 125),
    ]

    for shape, frame_hop in configurations:
      with self.test_session(use_gpu=True) as sess:
        signal = array_ops.zeros(shape)
        reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop)
        loss = math_ops.reduce_sum(reconstruction)
        # Increasing any sample in the input frames by one will increase the sum
        # of all the samples in the reconstruction by 1, so the gradient should
        # be all ones, no matter the shape or hop.
        gradient = sess.run(gradients_impl.gradients([loss], [signal])[0])
        self.assertTrue((gradient == 1.0).all())

  def test_gradient_batch(self):
    with self.test_session(use_gpu=True) as sess:
      signal = array_ops.zeros((2, 10, 10))
      frame_hop = 10
      reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop)

      # Multiply the first batch-item's reconstruction by zeros. This will block
      # gradient from flowing into the first batch item from the loss. Multiply
      # the second batch item by the integers from 0 to 99. Since there is zero
      # overlap, the gradient for this batch item will be 0-99 shaped as (10,
      # 10).
      reconstruction *= array_ops.stack(
          [array_ops.zeros((100,)), math_ops.to_float(math_ops.range(100))])
      loss = math_ops.reduce_sum(reconstruction)

      # Verify that only the second batch item receives gradient.
      gradient = sess.run(gradients_impl.gradients([loss], [signal])[0])
      expected_gradient = np.stack([
          np.zeros((10, 10)),
          np.reshape(np.arange(100).astype(np.float32), (10, 10))])
      self.assertAllEqual(expected_gradient, gradient)

  def test_gradient_numerical(self):
    with self.test_session(use_gpu=True):
      shape = (2, 10, 10)
      framed_signal = array_ops.zeros(shape)
      frame_hop = 10
      reconstruction = reconstruction_ops.overlap_and_add(
          framed_signal, frame_hop)
      error = test.compute_gradient_error(
          framed_signal, shape, reconstruction, [2, 100])
      self.assertLess(error, 2e-5)


if __name__ == "__main__":
  test.main()
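The test_all_ones expectation can be reproduced by hand: with three all-ones frames of length 5 and a hop of 2, each output sample counts how many frames cover it. A sketch (not part of the change):

import numpy as np

num_frames, frame_length, hop = 3, 5, 2
output = np.zeros(hop * (num_frames - 1) + frame_length, dtype=np.int64)
for f in range(num_frames):
    # Each all-ones frame adds 1 over the span it covers.
    output[f * hop:f * hop + frame_length] += 1
assert output.tolist() == [1, 1, 2, 2, 3, 2, 2, 1, 1]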
@@ -24,18 +24,18 @@ from tensorflow.contrib.signal.python.ops import shape_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


-class FramesTest(test.TestCase):
+class FrameTest(test.TestCase):

  def test_mapping_of_indices_without_padding(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
      tensor = constant_op.constant(np.arange(9152), dtypes.int32)
      tensor = array_ops.expand_dims(tensor, 0)

-      result = shape_ops.frames(tensor, 512, 180)
-      result = result.eval()
+      result = shape_ops.frame(tensor, 512, 180, pad_end=False).eval()

      expected = np.tile(np.arange(512), (49, 1))
      expected += np.tile(np.arange(49) * 180, (512, 1)).T
@@ -46,15 +46,14 @@ class FramesTest(test.TestCase):
      self.assertAllEqual(expected, result)

  def test_mapping_of_indices_with_padding(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
      tensor = constant_op.constant(np.arange(10000), dtypes.int32)
      tensor = array_ops.expand_dims(tensor, 0)

-      result = shape_ops.frames(tensor, 512, 192)
-      result = result.eval()
+      result = shape_ops.frame(tensor, 512, 192, pad_end=True).eval()

-      expected = np.tile(np.arange(512), (51, 1))
-      expected += np.tile(np.arange(51) * 192, (512, 1)).T
+      expected = np.tile(np.arange(512), (53, 1))
+      expected += np.tile(np.arange(53) * 192, (512, 1)).T

      expected[expected >= 10000] = 0

@@ -63,6 +62,277 @@ class FramesTest(test.TestCase):

      self.assertAllEqual(expected, result)

  def test_invalid_inputs(self):
    # Rank 0 input signal.
    with self.assertRaises(ValueError):
      shape_ops.frame(1, 1, 1)

    # If the rank is unknown, do not raise an exception.
    shape_ops.frame(array_ops.placeholder(dtypes.float32), 1, 1)

    # Non-scalar frame_length.
    with self.assertRaises(ValueError):
      shape_ops.frame([1], [1], 1)

    # Non-scalar frame_step.
    with self.assertRaises(ValueError):
      shape_ops.frame([1], 1, [1])

    # Non-scalar pad_value.
    with self.assertRaises(ValueError):
      shape_ops.frame([1], 1, 1, pad_end=True, pad_value=[1])

  def test_length_zero(self):
    signal = constant_op.constant([], dtype=dtypes.float32)
    frame_length = 2
    frame_step = 1

    with self.test_session(use_gpu=True):
      result = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=True, pad_value=99).eval()
      self.assertEqual((0, 2), result.shape)

      result = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=False).eval()
      self.assertEqual((0, 2), result.shape)

  def test_shape_inference(self):
    signal = array_ops.placeholder(dtypes.int32, shape=[1, 1])
    frame_length = 2
    frame_step = 1
    # Shape inference is able to detect the rank and inner-most dimension
    # if frame_length is known at graph definition time.
    result = shape_ops.frame(signal, frame_length, frame_step,
                             pad_end=True, pad_value=99)
    self.assertEqual([1, 1, 2], result.shape.as_list())

    result = shape_ops.frame(signal, frame_length, frame_step,
                             pad_end=False)
    self.assertEqual([1, 0, 2], result.shape.as_list())

    # If frame_length is not known, rank and (known) outer and inner dimensions
    # are inferred.
    signal = array_ops.placeholder(dtypes.int32, shape=[1, 2, 3, 4])
    frame_length = array_ops.placeholder(dtypes.int32, shape=[])
    frame_step = 1
    result = shape_ops.frame(signal, frame_length, frame_step,
                             pad_end=True, pad_value=99, axis=1)
    self.assertEqual([1, None, None, 3, 4], result.shape.as_list())

    result = shape_ops.frame(signal, frame_length, frame_step,
                             pad_end=False, axis=1)
    self.assertEqual([1, None, None, 3, 4], result.shape.as_list())

    # If frame_length and inner-most dimension is known, rank, inner dimensions,
    # and known outer dimensions are inferred.
    signal = array_ops.placeholder(dtypes.int32,
                                   shape=[None, 5, None, 20, 5, 3])
    frame_length = 4
    frame_step = 3
    result = shape_ops.frame(signal, frame_length, frame_step,
                             pad_end=True, pad_value=99, axis=3)
    self.assertEqual([None, 5, None, 7, 4, 5, 3], result.shape.as_list())

    result = shape_ops.frame(signal, frame_length, frame_step,
                             pad_end=False, axis=3)
    self.assertEqual([None, 5, None, 6, 4, 5, 3], result.shape.as_list())

    # Test that shape inference is consistent with actual returned shapes for
    # small values of signal_length, frame_length, frame_step, and pad_end in
    # [True, False].
    frame_step = 1
    for signal_length in range(2):
      signal = [0] * signal_length
      for frame_length in range(2):
        for pad_end in [False, True]:
          op = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=pad_end, pad_value=99)
          with self.test_session(use_gpu=True):
            result = op.eval()
            self.assertEqual(op.shape.as_list(), list(result.shape))

  def test_basic_mono(self):
    signal = np.arange(6)
    frame_length = 3
    frame_step = 2

    with self.test_session(use_gpu=True):
      for rank in range(5):
        nd_signal = np.reshape(signal, (1,) * rank + signal.shape)

        # With padding, we pad the last frame with pad_value.
        result = shape_ops.frame(nd_signal, frame_length, frame_step,
                                 pad_end=True, pad_value=99).eval()
        expected_inner_frames = np.array([[0, 1, 2], [2, 3, 4], [4, 5, 99]])
        expected = np.reshape(
            expected_inner_frames, (1,) * rank + expected_inner_frames.shape)
        self.assertAllEqual(expected, result)

        # Without padding, we drop the last frame.
        expected_inner_frames = np.array([[0, 1, 2], [2, 3, 4]])
        expected = np.reshape(
            expected_inner_frames, (1,) * rank + expected_inner_frames.shape)
        result = shape_ops.frame(nd_signal, frame_length, frame_step,
                                 pad_end=False).eval()
        self.assertAllEqual(expected, result)

  def test_basic_stereo(self):
    signal = np.vstack([np.arange(6),
                        np.arange(6) + 10])
    frame_length = 3
    frame_step = 2

    with self.test_session(use_gpu=True):
      for rank in range(5):
        nd_signal = np.reshape(signal, (1,) * rank + signal.shape)

        # With padding, we pad the last frame with pad_value.
        result = shape_ops.frame(nd_signal, frame_length, frame_step,
                                 pad_end=True, pad_value=99).eval()
        expected_inner_frames = np.array([
            [[0, 1, 2], [2, 3, 4], [4, 5, 99]],
            [[10, 11, 12], [12, 13, 14], [14, 15, 99]]])
        expected = np.reshape(
            expected_inner_frames, (1,) * rank + expected_inner_frames.shape)
        self.assertAllEqual(expected, result)

        # Without padding, we drop the last frame.
        expected_inner_frames = np.array([[[0, 1, 2], [2, 3, 4]],
                                          [[10, 11, 12], [12, 13, 14]]])
        expected = np.reshape(
            expected_inner_frames, (1,) * rank + expected_inner_frames.shape)
        result = shape_ops.frame(nd_signal, frame_length, frame_step,
                                 pad_end=False).eval()
        self.assertAllEqual(expected, result)

  def test_complex_shape(self):
    signal = np.vstack([np.arange(6),
                        np.arange(6) + 10,
                        np.arange(6) + 20,
                        np.arange(6) + 30,
                        np.arange(6) + 40,
                        np.arange(6) + 50])
    signal = np.reshape(signal, (2, 1, 3, 1, 6))
    frame_length = 3
    frame_step = 2

    with self.test_session(use_gpu=True):
      # With padding, we pad the last frame with pad_value.
      result = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=True, pad_value=99).eval()
      # Resulting shape is (2, 1, 3, 1, 3, 3).
      expected = [[[[[[0, 1, 2], [2, 3, 4], [4, 5, 99]]],
                    [[[10, 11, 12], [12, 13, 14], [14, 15, 99]]],
                    [[[20, 21, 22], [22, 23, 24], [24, 25, 99]]]]],
                  [[[[[30, 31, 32], [32, 33, 34], [34, 35, 99]]],
                    [[[40, 41, 42], [42, 43, 44], [44, 45, 99]]],
                    [[[50, 51, 52], [52, 53, 54], [54, 55, 99]]]]]]
      self.assertAllEqual(expected, result)

      result = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=False).eval()
      # Resulting shape is (2, 1, 3, 1, 2, 3): two full frames of length 3.
      expected = [[[[[[0, 1, 2], [2, 3, 4]]],
                    [[[10, 11, 12], [12, 13, 14]]],
                    [[[20, 21, 22], [22, 23, 24]]]]],
                  [[[[[30, 31, 32], [32, 33, 34]]],
                    [[[40, 41, 42], [42, 43, 44]]],
                    [[[50, 51, 52], [52, 53, 54]]]]]]
      self.assertAllEqual(expected, result)

  def test_axis(self):
    signal = np.reshape(np.arange(16), (2, 4, 2))
    with self.test_session(use_gpu=True):
      result = shape_ops.frame(signal, frame_length=2, frame_step=2,
                               pad_end=True, axis=1)
      expected = np.reshape(np.arange(16), (2, 2, 2, 2))
      self.assertAllEqual(expected, result.eval())

      result = shape_ops.frame(signal, frame_length=2, frame_step=1,
                               pad_end=True, axis=1)
      expected = [[[[0, 1], [2, 3]],
                   [[2, 3], [4, 5]],
                   [[4, 5], [6, 7]],
                   [[6, 7], [0, 0]]],
                  [[[8, 9], [10, 11]],
                   [[10, 11], [12, 13]],
                   [[12, 13], [14, 15]],
                   [[14, 15], [0, 0]]]]
      self.assertAllEqual(expected, result.eval())

      result = shape_ops.frame(signal, frame_length=3, frame_step=1,
                               pad_end=True, axis=1)
      expected = [[[[0, 1], [2, 3], [4, 5]],
                   [[2, 3], [4, 5], [6, 7]],
                   [[4, 5], [6, 7], [0, 0]],
                   [[6, 7], [0, 0], [0, 0]]],
                  [[[8, 9], [10, 11], [12, 13]],
                   [[10, 11], [12, 13], [14, 15]],
                   [[12, 13], [14, 15], [0, 0]],
                   [[14, 15], [0, 0], [0, 0]]]]
      self.assertAllEqual(expected, result.eval())

  def test_window_larger_than_signal(self):
    signal = constant_op.constant([[1, 2], [11, 12]], dtype=dtypes.float32)
    frame_length = 4
    frame_step = 1

    with self.test_session(use_gpu=True):
      result = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=True, pad_value=99).eval()
      self.assertAllClose([[[1, 2, 99, 99], [2, 99, 99, 99]],
                           [[11, 12, 99, 99], [12, 99, 99, 99]]], result)

      result = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=False).eval()
      self.assertEqual((2, 0, 4), result.shape)

      frame_step = 2
      result = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=True, pad_value=99).eval()
      self.assertAllClose([[[1, 2, 99, 99]], [[11, 12, 99, 99]]], result)

      result = shape_ops.frame(signal, frame_length, frame_step,
                               pad_end=False).eval()
      self.assertEqual((2, 0, 4), result.shape)

  def test_preserves_type(self):
    signal = math_ops.range(10, dtype=dtypes.float64)
    frame_length = 2
    frame_step = 3

    with self.test_session(use_gpu=True):
      result = shape_ops.frame(signal, frame_length, frame_step)
      self.assertEqual(result.dtype, signal.dtype)

  def test_dynamic_tensor(self):
    # Show that frame works even when the dimensions of its input are
    # not known at graph creation time.
    input_signal = np.vstack([np.arange(4), np.arange(4) + 10,
                              np.arange(4) + 20])
    frame_length = 2
    frame_step = 2

    with self.test_session(use_gpu=True) as sess:
      signal_placeholder = array_ops.placeholder(shape=(None, None),
                                                 dtype=dtypes.float32)
      result = sess.run(shape_ops.frame(
          signal_placeholder, frame_length, frame_step),
                        feed_dict={signal_placeholder: input_signal})
      self.assertAllEqual([[[0, 1], [2, 3]],
                           [[10, 11], [12, 13]],
                           [[20, 21], [22, 23]]], result)

  def test_gradient_numerical(self):
    with self.test_session(use_gpu=True):
      signal_shape = (2, 128)
      signal = array_ops.ones(signal_shape)
      frame_length = 33
      frame_step = 9
      frames = shape_ops.frame(signal, frame_length, frame_step)
      error = test.compute_gradient_error(
          signal, signal_shape, frames, frames.shape.as_list())
      self.assertLess(error, 2e-5)


if __name__ == "__main__":
  test.main()
tensorflow/contrib/signal/python/ops/reconstruction_ops.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Signal reconstruction via overlapped addition of frames."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.signal.python.ops import shape_ops
from tensorflow.contrib.signal.python.ops import util_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def _shuffle_to_front(input_tensor, k):
  """Shuffles the last `k` indices of `input_tensor` to the front.

  Transposes `input_tensor` to have the last `k` indices at the front. The input
  may have arbitrary rank and unknown shape.

  Args:
    input_tensor: A `Tensor` of arbitrary rank and unknown shape.
    k: A scalar `Tensor` specifying how many indices to shuffle.

  Returns:
    A transposed version of `input_tensor` with `k` indices shuffled to the
    front.

  Raises:
    ValueError: If `input_tensor` is not at least rank `k` or `k` is not scalar.
  """
  k = ops.convert_to_tensor(k, name="k")
  k.shape.with_rank(0)
  k_static = tensor_util.constant_value(k)
  if k_static is not None:
    input_tensor.shape.with_rank_at_least(k_static)

  rank = array_ops.rank(input_tensor)
  outer_indices, inner_indices = array_ops.split(math_ops.range(rank),
                                                 [rank - k, k])
  permutation = array_ops.concat([inner_indices, outer_indices], 0)

  return array_ops.transpose(input_tensor, perm=permutation)
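Numerically this is the same permutation NumPy's moveaxis performs; a quick illustration (not part of the diff) for rank 4 and k = 2:

import numpy as np

x = np.zeros((2, 3, 4, 5))
# Move the last two axes to the front: (2, 3, 4, 5) -> (4, 5, 2, 3).
assert np.moveaxis(x, [-2, -1], [0, 1]).shape == (4, 5, 2, 3)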
def overlap_and_add(signal, frame_step, name=None):
  """Reconstructs a signal from a framed representation.

  Adds potentially overlapping frames of a signal with shape
  `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
  The resulting tensor has shape `[..., output_size]` where

      output_size = (frames - 1) * frame_step + frame_length

  Args:
    signal: A [..., frames, frame_length] `Tensor`. All dimensions may be
      unknown, and rank must be at least 2.
    frame_step: An integer or scalar `Tensor` denoting overlap offsets. Must be
      less than or equal to `frame_length`.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with shape `[..., output_size]` containing the overlap-added
    frames of `signal`'s inner-most two dimensions.

  Raises:
    ValueError: If `signal`'s rank is less than 2, `frame_step` is not a scalar
      integer or `frame_step` is greater than `frame_length`.
  """
  with ops.name_scope(name, "overlap_and_add", [signal, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    signal.shape.with_rank_at_least(2)
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    frame_step.shape.assert_has_rank(0)
    if not frame_step.dtype.is_integer:
      raise ValueError("frame_step must be an integer. Got %s" %
                       frame_step.dtype)

    # If frame_length and frame_step are known at graph construction time, check
    # frame_step is less than or equal to frame_length.
    frame_step_static = tensor_util.constant_value(frame_step)
    if (frame_step_static is not None and signal.shape.ndims is not None and
        signal.shape[-1].value is not None and
        frame_step_static > signal.shape[-1].value):
      raise ValueError(
          "frame_step (%d) must be less than or equal to frame_length (%d)" % (
              frame_step_static, signal.shape[-1].value))

    signal_shape = array_ops.shape(signal)

    # All dimensions that are not part of the overlap-and-add. Can be empty for
    # rank 2 inputs.
    outer_dimensions = signal_shape[:-2]

    signal_rank = array_ops.rank(signal)
    frames = signal_shape[-2]
    frame_length = signal_shape[-1]

    subframe_length = util_ops.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # To avoid overlap-adding sample-by-sample, we overlap-add at the "subframe"
    # level, where a subframe is gcd(frame_length, frame_step). Reshape signal
    # from [..., frames, frame_length] into [..., subframes, subframe_length].
    subframe_shape = array_ops.concat(
        [outer_dimensions, [-1, subframe_length]], 0)
    subframe_signal = array_ops.reshape(signal, subframe_shape)

    # Now we shuffle the last [subframes, subframe_length] dimensions to the
    # front.
    # TODO(rjryan): Add an axis argument to unsorted_segment_sum so we can
    # avoid this pair of transposes.
    subframe_signal = _shuffle_to_front(subframe_signal, 2)

    # Use unsorted_segment_sum to add overlapping subframes together.
    segment_ids = array_ops.reshape(shape_ops.frame(
        math_ops.range(output_subframes), subframes_per_frame, subframe_step,
        pad_end=False), [-1])
    result = math_ops.unsorted_segment_sum(subframe_signal, segment_ids,
                                           num_segments=output_subframes)

    # result is a [subframes, subframe_length, ...outer_dimensions] tensor. We
    # return a [...outer_dimensions, output_size] tensor with a transpose and
    # reshape.
    result_shape = array_ops.concat([outer_dimensions, [output_size]], 0)
    return array_ops.reshape(_shuffle_to_front(result, signal_rank - 2),
                             result_shape)
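For intuition, the op computes the same result as this naive per-frame loop; the gcd/subframe machinery above only batches the summation. A NumPy sketch (illustrative, not part of the change):

import numpy as np

def overlap_and_add_reference(signal, frame_step):
    # Place frame i at offset i * frame_step and sum overlapping samples.
    num_frames, frame_length = signal.shape[-2:]
    output_size = (num_frames - 1) * frame_step + frame_length
    output = np.zeros(signal.shape[:-2] + (output_size,), dtype=signal.dtype)
    for i in range(num_frames):
        output[..., i * frame_step:i * frame_step + frame_length] += signal[..., i, :]
    return output

# Matches test_simple: frame length 2, hop 1 -> [1, 5, 9, 6].
assert overlap_and_add_reference(np.arange(1, 7).reshape(3, 2), 1).tolist() == [1, 5, 9, 6]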
@@ -18,70 +18,173 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-from tensorflow.python.framework import dtypes
+from tensorflow.contrib.signal.python.ops import util_ops
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


-def frames(signal, frame_length, frame_step, name=None):
-  """Frame a signal into overlapping frames.
-
-  May be used in front of spectral functions.
+def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
+  """Infers the shape of the return value of `frame`."""
+  frame_length = tensor_util.constant_value(frame_length)
+  frame_step = tensor_util.constant_value(frame_step)
+  axis = tensor_util.constant_value(axis)
+  if signal.shape.ndims is None:
+    return None
+  if axis is None:
+    return [None] * (signal.shape.ndims + 1)
+
+  signal_shape = signal.shape.as_list()
+  num_frames = None
+  frame_axis = signal_shape[axis]
+  outer_dimensions = signal_shape[:axis]
+  inner_dimensions = signal_shape[axis:][1:]
+  if signal_shape and frame_axis is not None:
+    if frame_step and frame_length is not None:
+      if pad_end:
+        # Double negative is so that we round up.
+        num_frames = -(-frame_axis // frame_step)
+      else:
+        num_frames = (frame_axis - frame_length + frame_step) // frame_step
+      num_frames = max(0, num_frames)
+  return outer_dimensions + [num_frames, frame_length] + inner_dimensions
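The two num_frames formulas are the standard padded/unpadded frame counts; an illustrative check against the 9152-sample, 512/180 case from shape_ops_test (49 full windows, 51 when padded):

samples, frame_length, frame_step = 9152, 512, 180
assert -(-samples // frame_step) == 51                             # pad_end=True: ceil
assert (samples - frame_length + frame_step) // frame_step == 49   # pad_end=False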
+def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
+          name=None):
+  """Expands `signal`'s `axis` dimension into frames of `frame_length`.
+
+  Slides a window of size `frame_length` over `signal`'s `axis` dimension
+  with a stride of `frame_step`, replacing the `axis` dimension with
+  `[frames, frame_length]` frames.
+
+  If `pad_end` is True, window positions that are past the end of the `axis`
+  dimension are padded with `pad_value` until the window moves fully past the
+  end of the dimension. Otherwise, only window positions that fully overlap the
+  `axis` dimension are produced.

  For example:

  ```python
  pcm = tf.placeholder(tf.float32, [None, 9152])
-  frames = tf.contrib.signal.frames(pcm, 512, 180)
+  frames = tf.contrib.signal.frame(pcm, 512, 180)
  magspec = tf.abs(tf.spectral.rfft(frames, [512]))
  image = tf.expand_dims(magspec, 3)
  ```

  Args:
-    signal: A `Tensor` of shape `[batch_size, signal_length]`.
-    frame_length: An `int32` or `int64` `Tensor`. The length of each frame.
-    frame_step: An `int32` or `int64` `Tensor`. The step between frames.
-    name: A name for the operation (optional).
+    signal: A `[..., samples, ...]` `Tensor`. The rank and dimensions
+      may be unknown. Rank must be at least 1.
+    frame_length: The frame length in samples. An integer or scalar `Tensor`.
+    frame_step: The frame hop size in samples. An integer or scalar `Tensor`.
+    pad_end: Whether to pad the end of `signal` with `pad_value`.
+    pad_value: An optional scalar `Tensor` to use where the input signal
+      does not exist when `pad_end` is True.
+    axis: A scalar integer `Tensor` indicating the axis to frame. Defaults to
+      the last axis. Supports negative values for indexing from the end.
+    name: An optional name for the operation.

  Returns:
-    A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`.
+    A `Tensor` of frames with shape `[..., frames, frame_length, ...]`.

  Raises:
-    ValueError: if signal does not have rank 2.
+    ValueError: If `frame_length`, `frame_step`, or `pad_value` are not scalar.
  """
-  with ops.name_scope(name, "frames", [signal, frame_length, frame_step]):
+  with ops.name_scope(name, "frame", [signal, frame_length, frame_step,
+                                      pad_value]):
    signal = ops.convert_to_tensor(signal, name="signal")
    frame_length = ops.convert_to_tensor(frame_length, name="frame_length")
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
+    axis = ops.convert_to_tensor(axis, name="axis")

-    signal_rank = signal.shape.ndims
+    signal.shape.with_rank_at_least(1)
+    frame_length.shape.assert_has_rank(0)
+    frame_step.shape.assert_has_rank(0)
+    axis.shape.assert_has_rank(0)

-    if signal_rank != 2:
-      raise ValueError("expected signal to have rank 2 but was " + signal_rank)
+    result_shape = _infer_frame_shape(signal, frame_length, frame_step, pad_end,
+                                      axis)

-    signal_length = array_ops.shape(signal)[1]
+    # Axis can be negative. Convert it to positive.
+    signal_rank = array_ops.rank(signal)
+    axis = math_ops.range(signal_rank)[axis]

-    num_frames = math_ops.ceil((signal_length - frame_length) / frame_step)
-    num_frames = 1 + math_ops.cast(num_frames, dtypes.int32)
+    signal_shape = array_ops.shape(signal)
+    outer_dimensions, length_samples, inner_dimensions = array_ops.split(
+        signal_shape, [axis, 1, signal_rank - 1 - axis])
+    length_samples = array_ops.reshape(length_samples, [])
+    num_outer_dimensions = array_ops.size(outer_dimensions)
+    num_inner_dimensions = array_ops.size(inner_dimensions)

-    pad_length = (num_frames - 1) * frame_step + frame_length
-    pad_signal = array_ops.pad(signal, [[0, 0], [0,
-                                                 pad_length - signal_length]])
+    # If padding is requested, pad the input signal tensor with pad_value.
+    if pad_end:
+      pad_value = ops.convert_to_tensor(pad_value, signal.dtype)
+      pad_value.shape.assert_has_rank(0)

-    indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0)
-    indices_frames = array_ops.tile(indices_frame, [num_frames, 1])
+      # Calculate number of frames, using double negatives to round up.
+      num_frames = -(-length_samples // frame_step)

-    indices_step = array_ops.expand_dims(
-        math_ops.range(num_frames) * frame_step, 1)
-    indices_steps = array_ops.tile(indices_step, [1, frame_length])
+      # Pad the signal by up to frame_length samples based on how many samples
+      # are remaining starting from last_frame_position.
+      pad_samples = math_ops.maximum(
+          0, frame_length + frame_step * (num_frames - 1) - length_samples)

-    indices = indices_frames + indices_steps
+      # Pad the inner dimension of signal by pad_samples.
+      paddings = array_ops.concat(
+          [array_ops.zeros([num_outer_dimensions, 2], dtype=pad_samples.dtype),
+           [[0, pad_samples]],
+           array_ops.zeros([num_inner_dimensions, 2], dtype=pad_samples.dtype)],
+          0)
+      signal = array_ops.pad(signal, paddings, constant_values=pad_value)

-    # TODO(androbin): remove `transpose` when `gather` gets `axis` support
-    pad_signal = array_ops.transpose(pad_signal)
-    signal_frames = array_ops.gather(pad_signal, indices)
-    signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1])
+      signal_shape = array_ops.shape(signal)
+      length_samples = signal_shape[axis]
+    else:
+      num_frames = math_ops.maximum(
+          0, 1 + (length_samples - frame_length) // frame_step)

-    return signal_frames
+    subframe_length = util_ops.gcd(frame_length, frame_step)
+    subframes_per_frame = frame_length // subframe_length
+    subframes_per_hop = frame_step // subframe_length
+    num_subframes = length_samples // subframe_length
+
+    slice_shape = array_ops.concat([outer_dimensions,
+                                    [num_subframes * subframe_length],
+                                    inner_dimensions], 0)
+    subframe_shape = array_ops.concat([outer_dimensions,
+                                       [num_subframes, subframe_length],
+                                       inner_dimensions], 0)
+    subframes = array_ops.reshape(array_ops.strided_slice(
+        signal, array_ops.zeros_like(signal_shape),
+        slice_shape), subframe_shape)
+
+    # frame_selector is a [num_frames, subframes_per_frame] tensor
+    # that indexes into the appropriate frame in subframes. For example:
+    # [[0, 0, 0, 0], [2, 2, 2, 2], [4, 4, 4, 4]]
+    frame_selector = array_ops.reshape(
+        math_ops.range(num_frames) * subframes_per_hop, [num_frames, 1])
+
+    # subframe_selector is a [num_frames, subframes_per_frame] tensor
+    # that indexes into the appropriate subframe within a frame. For example:
+    # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
+    subframe_selector = array_ops.reshape(
+        math_ops.range(subframes_per_frame), [1, subframes_per_frame])
+
+    # Adding the 2 selector tensors together produces a [num_frames,
+    # subframes_per_frame] tensor of indices to use with tf.gather to select
+    # subframes from subframes. We then reshape the inner-most
+    # subframes_per_frame dimension to stitch the subframes together into
+    # frames. For example: [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]].
+    selector = frame_selector + subframe_selector
+
+    frames = array_ops.reshape(
+        array_ops.gather(subframes, selector, axis=axis),
+        array_ops.concat([outer_dimensions, [num_frames, frame_length],
+                          inner_dimensions], 0))
+
+    if result_shape:
+      frames.set_shape(result_shape)
+    return frames
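End to end, frame's contract on a 1-D signal reduces to the following NumPy sketch (illustrative only; the op generalizes it to arbitrary rank and axis):

import numpy as np

def frame_reference(signal, frame_length, frame_step, pad_end=False, pad_value=0):
    signal = np.asarray(signal)
    n = len(signal)
    if pad_end:
        num_frames = -(-n // frame_step)  # ceil(n / frame_step)
        padded = (num_frames - 1) * frame_step + frame_length
        signal = np.concatenate(
            [signal, np.full(padded - n, pad_value, dtype=signal.dtype)])
    else:
        num_frames = max(0, 1 + (n - frame_length) // frame_step)
    return np.array([signal[i * frame_step:i * frame_step + frame_length]
                     for i in range(num_frames)])

# Mirrors test_basic_mono: framing [0..5] with frame_length=3, frame_step=2.
assert frame_reference(np.arange(6), 3, 2, pad_end=True, pad_value=99).tolist() == \
    [[0, 1, 2], [2, 3, 4], [4, 5, 99]]
assert frame_reference(np.arange(6), 3, 2).tolist() == [[0, 1, 2], [2, 3, 4]]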
tensorflow/contrib/signal/python/ops/util_ops.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility ops shared across tf.contrib.signal."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


def gcd(a, b, name=None):
  """Returns the greatest common divisor via Euclid's algorithm.

  Args:
    a: The dividend. A scalar integer `Tensor`.
    b: The divisor. A scalar integer `Tensor`.
    name: An optional name for the operation.

  Returns:
    A scalar `Tensor` representing the greatest common divisor between `a` and
    `b`.

  Raises:
    ValueError: If `a` or `b` are not scalar integers.
  """
  with ops.name_scope(name, 'gcd', [a, b]):
    a = ops.convert_to_tensor(a)
    b = ops.convert_to_tensor(b)

    a.shape.assert_has_rank(0)
    b.shape.assert_has_rank(0)

    if not a.dtype.is_integer:
      raise ValueError('a must be an integer type. Got: %s' % a.dtype)
    if not b.dtype.is_integer:
      raise ValueError('b must be an integer type. Got: %s' % b.dtype)

    cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b))
    body = lambda a, b: [b, math_ops.mod(a, b)]
    a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False)
    return a
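Since gcd(frame_length, frame_step) drives the subframe decomposition in both frame and overlap_and_add, here is a quick plain-Python cross-check of the while-loop's Euclid step (illustrative):

def euclid(a, b):
    # Same step as the loop body above: [a, b] -> [b, a mod b] until b == 0.
    while b > 0:
        a, b = b, a % b
    return a

assert euclid(512, 180) == 4  # e.g. frame_length=512, frame_step=180 -> subframe_length 4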
@@ -207,7 +207,8 @@ void ClassificationStats::AddExample(
}

void ClassificationStats::CheckPrune() {
-  if (IsFinished() || weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
+  if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
+      weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
    return;
  }
  ++prune_sample_epoch_;
@@ -1514,6 +1514,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
  params.input_device_contexts = &input_device_contexts;
  params.input_alloc_attrs = &input_alloc_attrs;
  params.runner = &runner_;
  params.stats_collector = stats_collector_;

  Status s;
  NodeExecStats* stats = nullptr;
@@ -261,6 +261,7 @@ class CallOp : public AsyncOpKernel {
    FunctionLibraryRuntime::Options opts;
    opts.step_id = ctx->step_id();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
@@ -545,6 +546,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
  // Inherit the step_id from the caller.
  exec_args.step_id = opts.step_id;
  exec_args.step_container = opts.step_container;

  exec_args.stats_collector = opts.stats_collector;
  exec_args.call_frame = frame;
  exec_args.cancellation_manager = opts.cancellation_manager;
  exec_args.runner = *opts.runner;
@@ -38,6 +38,7 @@ class GraphDef;
class OpKernel;
class ResourceMgr;
class ScopedStepContainer;
class StepStatsCollector;
class Node;

// FunctionDefHelper::Create is a convenient helper to construct a
@@ -402,7 +403,8 @@ class FunctionLibraryRuntime {
    int64 step_id = 0;

    // Per-step container.
-    ScopedStepContainer* step_container;
+    ScopedStepContainer* step_container = nullptr;
+    StepStatsCollector* stats_collector = nullptr;

    std::function<void(std::function<void()>)>* runner = nullptr;
  };
@@ -72,6 +72,7 @@ class OpKernelContext;  // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class StepStatsCollector;

class OpKernel {
 public:
@@ -551,6 +552,7 @@ class OpKernelContext {
    FunctionCallFrame* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollector* stats_collector = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -940,6 +942,9 @@ class OpKernelContext {
  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollector* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }
@@ -115,6 +115,7 @@ void ConcatGPU(
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
TF_CALL_complex64(REGISTER);
TF_CALL_complex128(REGISTER);
TF_CALL_int64(REGISTER);
REGISTER(bfloat16);

#undef REGISTER
@@ -201,21 +201,25 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
TF_CALL_complex64(REGISTER_GPUCONCAT32);
TF_CALL_complex128(REGISTER_GPUCONCAT32);
TF_CALL_int64(REGISTER_GPUCONCAT32);
REGISTER_GPUCONCAT32(bfloat16);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
TF_CALL_complex64(REGISTER_GPUCONCAT64);
TF_CALL_complex128(REGISTER_GPUCONCAT64);
TF_CALL_int64(REGISTER_GPUCONCAT64);
REGISTER_GPUCONCAT64(bfloat16);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
TF_CALL_complex64(REGISTER_GPU32);
TF_CALL_complex128(REGISTER_GPU32);
TF_CALL_int64(REGISTER_GPU32);
REGISTER_GPU32(bfloat16);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
TF_CALL_complex64(REGISTER_GPU64);
TF_CALL_complex128(REGISTER_GPU64);
TF_CALL_int64(REGISTER_GPU64);
REGISTER_GPU64(bfloat16);

#undef REGISTER_GPUCONCAT32
@@ -195,6 +195,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
REGISTER_GPU(bfloat16);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
#undef REGISTER_GPU

// A special GPU kernel for int32.
@@ -216,7 +216,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
        TensorShapeUtils::IsScalar(context_dense_keys[di].shape()),
        errors::InvalidArgument(
            "Expected context_dense_keys[", di,
-            "] to be a vector, got shape: ",
+            "] to be a scalar, got shape: ",
            context_dense_keys[di].shape().DebugString()));
    context_dense_keys_t[di] = context_dense_keys[di].scalar<string>()();
  }
@@ -225,7 +225,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
        TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()),
        errors::InvalidArgument(
            "Expected context_sparse_keys[", di,
-            "] to be a vector, got shape: ",
+            "] to be a scalar, got shape: ",
            context_sparse_keys[di].shape().DebugString()));
    context_sparse_keys_t[di] = context_sparse_keys[di].scalar<string>()();
  }
@@ -234,7 +234,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
        ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()),
        errors::InvalidArgument(
            "Expected feature_list_dense_keys[", di,
-            "] to be a vector, got shape: ",
+            "] to be a scalar, got shape: ",
            feature_list_dense_keys[di].shape().DebugString()));
    feature_list_dense_keys_t[di] =
        feature_list_dense_keys[di].scalar<string>()();
@ -244,7 +244,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
|
||||
ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()),
|
||||
errors::InvalidArgument(
|
||||
"Expected feature_list_sparse_keys[", di,
|
||||
"] to be a vector, got shape: ",
|
||||
"] to be a scalar, got shape: ",
|
||||
feature_list_sparse_keys[di].shape().DebugString()));
|
||||
feature_list_sparse_keys_t[di] =
|
||||
feature_list_sparse_keys[di].scalar<string>()();
|
||||
|
@ -157,6 +157,7 @@ REGISTER_PACK(string);
|
||||
PackOp<GPUDevice, type>)
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
|
||||
TF_CALL_int64(REGISTER_GPU);
|
||||
#undef REGISTER_GPU
|
||||
|
||||
// A special GPU kernel for int32.
|
||||
|
298 tensorflow/docs_src/get_started/export.md Normal file
@ -0,0 +1,298 @@
# Exporting a Trained Model for Serving

Once you have trained an `Estimator` model, you may want to create a service
from that model that takes requests and returns a result. You can run such a
service locally on your machine or deploy it scalably in the cloud.

To prepare a trained Estimator for serving, you must export it in the standard
[`SavedModel`](https://www.tensorflow.org/code/tensorflow/python/saved_model/README.md)
format, which wraps the TensorFlow graph, the trained variable values, any
required assets, and metadata together in a hermetic package.

In this tutorial, we will discuss how to:

* Add graph nodes that accept and prepare inference requests
* Specify the output nodes and the corresponding [APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto)
  that can be served (Classify, Regress, or Predict)
* Export your model to the `SavedModel` format
* Deploy the model in Google Cloud ML Engine and request predictions
* Serve the model from a local server and request predictions


## The exported graph and its signatures

The export procedure assembles a new TensorFlow graph from two main components:
1) a Serving Input Receiver that defines the format of the inputs to be
accepted, and 2) the trained model itself.

An exported `SavedModel` contains that combined graph packaged together with one
or more *signatures*. Like a function signature in any programming language, a
graph signature specifies the required inputs (arguments) and the expected
outputs (return values) of the computation. In the typical case, a single
signature is present, corresponding to the predictions that the model has
learned to make.

The *input* portion of the signature is determined by the Serving Input
Receiver. To specify the inputs that your deployed model will accept, you must
provide a `serving_input_receiver_fn()` to `estimator.export_savedmodel()` (see
below).

The *output* portion of the signature is determined by the model. For instance,
canned Estimators know the nature of the outputs they produce (e.g. whether the
output is a classification or a regression, and the type and shape of those
outputs). Custom Estimators must provide this information via `export_outputs`
(see [below](#specifying_the_outputs_of_a_custom_model)).

> Note: A *multi-headed model* provides multiple signatures, each corresponding
> to a different "head", i.e. a set of predictions that can be made from the
> same inputs by executing a subgraph of the complete trained graph. The
> *output* portions of these signatures are determined by the model.



## Preparing serving inputs

During training, an @{$input_fn$`input_fn()`} ingests data and prepares it for
use by the model. At serving time, similarly, a `serving_input_receiver_fn()`
accepts inference requests and prepares them for the model. The purpose of this
function is to add placeholders to the graph which the serving system will feed
with inference requests, as well as to add any additional ops needed to convert
data from the input format into the feature `Tensor`s expected by the model.
The function returns a @{tf.estimator.export.ServingInputReceiver} object, which
packages the placeholders and the resulting feature `Tensor`s together.

A typical pattern is that inference requests arrive in the form of serialized
`tf.Example`s, so the `serving_input_receiver_fn()` creates a single string
placeholder to receive them. The `serving_input_receiver_fn()` is then also
responsible for parsing the `tf.Example`s by adding a @{tf.parse_example} op to
the graph.

When writing such a `serving_input_receiver_fn()`, you must pass a parsing
specification to @{tf.parse_example} to tell the parser what feature names to
expect and how to map them to `Tensor`s. A parsing specification takes the
form of a dict from feature names to @{tf.FixedLenFeature}, @{tf.VarLenFeature},
and @{tf.SparseFeature}. (Note this parsing specification should not include
any label or weight columns, since those will not be available at serving
time—in contrast to a parsing specification used in the `input_fn()` at
training time.)

In combination, then:

```py
feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}

default_batch_size = None  # a batch dimension of None accepts any batch size

def serving_input_receiver_fn():
  """An input receiver that expects a serialized tf.Example."""
  serialized_tf_example = tf.placeholder(dtype=tf.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
  receiver_tensors = {'examples': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
```

The @{tf.estimator.export.build_parsing_serving_input_receiver_fn} utility
function provides that input receiver for the common case.
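
For instance, a minimal sketch of using that utility with the `feature_spec`
dict defined above; the result behaves like the hand-written
`serving_input_receiver_fn()`:

```py
# Build a parsing serving_input_receiver_fn from the same feature_spec
# used above; it creates the string placeholder and parse_example op for us.
serving_input_receiver_fn = (
    tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
```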

> Note: when training a model to be served using Google Cloud ML Engine (see
> below), the parsing step is not needed, because the model will receive raw
> feature data. This is also true when using the Predict API with a local
> server.

Even if you require no parsing or other input processing—i.e., if the
serving system will feed feature `Tensor`s directly—you must still provide
a `serving_input_receiver_fn()` that creates placeholders for the feature
`Tensor`s and passes them through. The
@{tf.estimator.export.build_raw_serving_input_receiver_fn} utility provides for
this.
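
A minimal sketch of this raw (passthrough) case, assuming the model expects a
single float feature named `x`; the name and shape here are illustrative:

```py
# Placeholders are fed directly with feature Tensors; no parsing op is added.
feature_placeholders = {
    'x': tf.placeholder(dtype=tf.float32, shape=[None, 3], name='x'),
}
serving_input_receiver_fn = (
    tf.estimator.export.build_raw_serving_input_receiver_fn(
        feature_placeholders))
```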

If these utilities do not meet your needs, you are free to write your own
`serving_input_receiver_fn()`. One case where this may be needed is if your
training `input_fn()` incorporates some preprocessing logic that must be
recapitulated at serving time. To reduce the risk of training-serving skew, we
recommend encapsulating such processing in a function which is then called
from both `input_fn()` and `serving_input_receiver_fn()`, as sketched below.
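
A hedged sketch of that factoring; `_center_foo` and the feature name are
hypothetical, and `feature_spec` is the parsing specification from above:

```py
def _center_foo(features):
  # Shared feature engineering, applied identically at training and serving
  # time; assumes 'foo' is a float feature.
  features['foo'] = features['foo'] - 10.0
  return features

def serving_input_receiver_fn():
  serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[None])
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(
      _center_foo(features), {'examples': serialized_tf_example})

# The training input_fn() should call _center_foo(features) on the features
# it produces, so both paths share the same logic.
```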


## Performing the export

To export your trained Estimator, call
@{tf.estimator.Estimator.export_savedmodel} with the export base path, together
with the `serving_input_receiver_fn`.

```py
estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)
```

This method builds a new graph by first calling the
`serving_input_receiver_fn()` to obtain feature `Tensor`s, and then calling
this `Estimator`'s `model_fn()` to generate the model graph based on those
features. It starts a fresh `Session` and, by default, restores the most recent
checkpoint into it. (A different checkpoint may be passed, if needed.)
Finally it creates a timestamped export directory below the given
`export_dir_base` (i.e., `export_dir_base/<timestamp>`), and writes a
`SavedModel` into it containing a single `MetaGraphDef` saved from this
Session.
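
For instance, to export from a specific checkpoint rather than the latest one,
a sketch assuming the optional `checkpoint_path` argument; the checkpoint file
name here is hypothetical:

```py
estimator.export_savedmodel(
    export_dir_base,
    serving_input_receiver_fn,
    checkpoint_path='/tmp/my_model/model.ckpt-42')  # hypothetical checkpoint
```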

> Note: there is currently no built-in mechanism to garbage-collect old exports,
> so successive exports will accumulate under `export_dir_base` unless deleted
> by some external means.

## Specifying the outputs of a custom model

When writing a custom `model_fn`, you must populate the `export_outputs` element
of the @{tf.estimator.EstimatorSpec} return value. This is a dict of
`{name: output}` describing the output signatures to be exported and used during
serving.

In the usual case of making a single prediction, this dict contains
one element, and the `name` is immaterial. In a multi-headed model, each head
is represented by an entry in this dict. In this case the `name` is a string
of your choice that can be used to request a specific head at serving time.

Each `output` value must be an `ExportOutput` object such as
@{tf.estimator.export.ClassificationOutput},
@{tf.estimator.export.RegressionOutput}, or
@{tf.estimator.export.PredictOutput}.
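
A hedged sketch of the prediction branch of a custom `model_fn` populating
`export_outputs`; the toy model and tensor names are illustrative, and the
training and evaluation branches are omitted:

```py
def model_fn(features, labels, mode):
  logits = tf.layers.dense(features['x'], 10)  # toy model
  predictions = {'classes': tf.argmax(logits, axis=1),
                 'probabilities': tf.nn.softmax(logits)}
  export_outputs = {
      'predict': tf.estimator.export.PredictOutput(predictions)
  }
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, export_outputs=export_outputs)
  # ... training and evaluation branches omitted ...
```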

These output types map straightforwardly to the
[TensorFlow Serving APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto),
and so determine which request types will be honored.

> Note: In the multi-headed case, a `SignatureDef` will be generated for each
> element of the `export_outputs` dict returned from the model_fn, named using
> the same keys. These signatures differ only in their outputs, as provided by
> the corresponding `ExportOutput` entry. The inputs are always those provided
> by the `serving_input_receiver_fn`.
> An inference request may specify the head by name. One head must be named
> using [`signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`](https://www.tensorflow.org/code/saved_model/signature_constants.py),
> indicating which signature will be served when an inference request does not
> specify one.
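
For instance, a sketch of a two-headed `export_outputs` dict, assuming `scores`
and `value` are Tensors computed by the respective heads; the key
`'regress_head'` is an arbitrary name of your choice:

```py
from tensorflow.python.saved_model import signature_constants

export_outputs = {
    # Served when a request does not name a signature.
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        tf.estimator.export.ClassificationOutput(scores=scores),
    # Served when a request asks for 'regress_head'.
    'regress_head':
        tf.estimator.export.RegressionOutput(value=value),
}
```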


## Serving the exported model on Google Cloud ML Engine

[Google Cloud ML Engine](https://cloud.google.com/ml-engine/) provides a fully
managed, scalable environment for serving your trained SavedModels to make
online or batch predictions.

Please see [Deploying Models](https://cloud.google.com/ml-engine/docs/how-tos/deploying-models)
to learn how to deploy your SavedModel on Cloud ML Engine.

> Note: Cloud ML Engine accepts inference requests in JSON, CSV, or TFRecords
> formats, depending on the circumstance. Parsing these formats is not the
> responsibility of the graph. Cloud ML Engine does the parsing for you, and
> feeds raw feature data directly into the graph. Thus, when targeting Cloud ML
> Engine, you should use a `serving_input_receiver_fn()` of the passthrough form
> that simply creates placeholders for each feature.


## Requesting predictions from Google Cloud ML Engine

To learn how to request predictions from a model deployed in Cloud ML Engine,
please see:

* [Prediction Basics](https://cloud.google.com/ml-engine/docs/concepts/prediction-overview)
* [Getting Online Predictions](https://cloud.google.com/ml-engine/docs/how-tos/online-predict)
* [Getting Batch Predictions](https://cloud.google.com/ml-engine/docs/how-tos/batch-predict)


## Serving the exported model locally

For local deployment, you can serve your model using
@{$deploy/tfserve$TensorFlow Serving}, an open-source project that loads a
`SavedModel` and exposes it as a [gRPC](http://www.grpc.io/) service.

First, [install TensorFlow Serving](https://tensorflow.github.io/serving/setup#prerequisites).

Then build and run the local model server, substituting `$export_dir_base` with
the path to the `SavedModel` you exported above:

```sh
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_base_path=$export_dir_base
```

Now you have a server listening for inference requests via gRPC on port 9000!


## Requesting predictions from a local server

The server responds to gRPC requests according to the [PredictionService](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto#L15)
gRPC API service definition. (The nested protocol buffers are defined in
various [neighboring files](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis).)

From the API service definition, the gRPC framework generates client libraries
in various languages providing remote access to the API. In a project using the
Bazel build tool, these libraries are built automatically and provided via
dependencies like these (using Python for example):

```build
deps = [
    "//tensorflow_serving/apis:classification_proto_py_pb2",
    "//tensorflow_serving/apis:regression_proto_py_pb2",
    "//tensorflow_serving/apis:predict_proto_py_pb2",
    "//tensorflow_serving/apis:prediction_service_proto_py_pb2"
]
```

Python client code can then import the libraries thus:

```py
from tensorflow_serving.apis import classification_pb2
from tensorflow_serving.apis import regression_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
```

> Note: `prediction_service_pb2` defines the service as a whole and so
> is always required. However a typical client will need only one of
> `classification_pb2`, `regression_pb2`, and `predict_pb2`, depending on the
> type of requests being made.

Sending a gRPC request is then accomplished by assembling a protocol buffer
containing the request data and passing it to the service stub. Note how the
request protocol buffer is created empty and then populated via the
[generated protocol buffer API](https://developers.google.com/protocol-buffers/docs/reference/python-generated).

```py
from grpc.beta import implementations

channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

request = classification_pb2.ClassificationRequest()
example = request.input.example_list.examples.add()
# 'image' here is a NumPy array prepared by the caller.
example.features.feature['x'].float_list.value.extend(image[0].astype(float))

result = stub.Classify(request, 10.0)  # 10 secs timeout
```

The returned result in this example is a `ClassificationResponse` protocol
buffer.
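
A hedged sketch of reading it, assuming the nested structure defined in
`classification.proto` (a `result` containing `classifications`, each holding
repeated `classes` with a `label` and a `score`):

```py
# Print the classifications returned for the first (and only) input example.
for class_result in result.result.classifications[0].classes:
  print(class_result.label, class_result.score)
```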

This is a skeletal example; please see the @{$deploy$TensorFlow Serving}
documentation and [examples](https://github.com/tensorflow/serving/tree/master/tensorflow_serving/example)
for more details.

> Note: `ClassificationRequest` and `RegressionRequest` contain a
> `tensorflow.serving.Input` protocol buffer, which in turn contains a list of
> `tensorflow.Example` protocol buffers. `PredictRequest`, by contrast,
> contains a mapping from feature names to values encoded via `TensorProto`.
> Correspondingly: when using the `Classify` and `Regress` APIs, TensorFlow
> Serving feeds serialized `tf.Example`s to the graph, so your
> `serving_input_receiver_fn()` should include a `tf.parse_example()` Op.
> When using the generic `Predict` API, however, TensorFlow Serving feeds raw
> feature data to the graph, so a passthrough `serving_input_receiver_fn()`
> should be used.


<!-- TODO(soergel): give examples of making requests against this server, using
the different Tensorflow Serving APIs, selecting the signature by key, etc. -->

<!-- TODO(soergel): document ExportStrategy here once Experiment moves
from contrib to core. -->

@ -26,6 +26,8 @@ To learn about the high-level API, read the following guides:
  which takes you into a somewhat more sophisticated use of this API.
* @{$get_started/monitors$Logging and Monitoring Basics with tf.contrib.learn},
  which explains how to audit the progress of model training.
* @{$get_started/export$Exporting a Trained Model for Serving}, which shows
  how to save a trained model in a form that is ready to deploy.

TensorBoard is a utility to visualize different aspects of machine learning.
The following guides explain how to use TensorBoard:

@ -141,6 +141,7 @@ class ConcatOpTest(test.TestCase):
    self._testRandom(dtypes.float32)
    self._testRandom(dtypes.int16)
    self._testRandom(dtypes.int32)
    self._testRandom(dtypes.int64)
    self._testRandom(dtypes.bfloat16)
    self._testRandom(dtypes.complex64)
    self._testRandom(dtypes.complex128)

@ -45,45 +45,61 @@ class StackOpTest(test.TestCase):
    np.random.seed(7)
    with self.test_session(use_gpu=True):
      for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
        data = np.random.randn(*shape)
        # Convert [data[0], data[1], ...] separately to tensorflow
        # TODO(irving): Remove list() once we handle maps correctly
        xs = list(map(constant_op.constant, data))
        # Pack back into a single tensorflow tensor
        c = array_ops.stack(xs)
        self.assertAllEqual(c.eval(), data)
        for dtype in [np.float32, np.int32, np.int64]:
          data = np.random.randn(*shape).astype(dtype)
          # Convert [data[0], data[1], ...] separately to tensorflow
          # TODO(irving): Remove list() once we handle maps correctly
          xs = list(map(constant_op.constant, data))
          # Pack back into a single tensorflow tensor
          c = array_ops.stack(xs)
          self.assertAllEqual(c.eval(), data)

  def testSimpleParallel(self):
    np.random.seed(7)
    with self.test_session(use_gpu=True):
      for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
        data = np.random.randn(*shape).astype(np.float32)
        xs = list(map(constant_op.constant, data))
        c = array_ops.parallel_stack(xs)
        self.assertAllEqual(c.eval(), data)

  def testConst(self):
    np.random.seed(7)
    with self.test_session(use_gpu=True):
      for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
        for dtype in [np.float32, np.int32, np.int64]:
          data = np.random.randn(*shape).astype(dtype)
          # Pack back into a single tensorflow tensor directly using np array
          c = array_ops.stack(data)
          # This is implemented via a Const:
          self.assertEqual(c.op.type, "Const")
          self.assertAllEqual(c.eval(), data)

          # Python lists also work for 1-D case:
          if len(shape) == 1:
            data_list = list(data)
            cl = array_ops.stack(data_list)
            self.assertEqual(cl.op.type, "Const")
            self.assertAllEqual(cl.eval(), data)

      # Verify that shape induction works with shapes produced via const stack
      a = constant_op.constant([1, 2, 3, 4, 5, 6])
      b = array_ops.reshape(a, array_ops.stack([2, 3]))
      self.assertAllEqual(b.get_shape(), [2, 3])

  def testConstParallel(self):
    np.random.seed(7)
    with self.test_session(use_gpu=True):
      for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
        data = np.random.randn(*shape).astype(np.float32)
        # Pack back into a single tensorflow tensor directly using np array
        c = array_ops.stack(data)
        # This is implemented via a Const:
        self.assertEqual(c.op.type, "Const")
        self.assertAllEqual(c.eval(), data)

        c = array_ops.parallel_stack(data)
        self.assertAllEqual(c.eval(), data)

        # Python lists also work for 1-D case:
        if len(shape) == 1:
          data_list = list(data)
          cl = array_ops.stack(data_list)
          self.assertEqual(cl.op.type, "Const")
          self.assertAllEqual(cl.eval(), data)

          cl = array_ops.parallel_stack(data_list)
          self.assertAllEqual(cl.eval(), data)

      # Verify that shape induction works with shapes produced via const stack
      a = constant_op.constant([1, 2, 3, 4, 5, 6])
      b = array_ops.reshape(a, array_ops.stack([2, 3]))
      self.assertAllEqual(b.get_shape(), [2, 3])
        data = np.random.randn(*shape).astype(np.float32)
        c = array_ops.parallel_stack(data)
        self.assertAllEqual(c.eval(), data)

  def testGradientsAxis0(self):
    np.random.seed(7)

@ -69,10 +69,13 @@ ln -s $(pwd)/tensorflow ${PIP_TEST_ROOT}/tensorflow

# Do not run tests with "no_pip" tag. If running GPU tests, also do not run
# tests with no_pip_gpu tag.
PIP_TEST_FILTER_TAG="-no_pip"
PIP_TEST_FILTER_TAG="-no_oss,-no_pip"
if [[ ${IS_GPU} == "1" ]]; then
  PIP_TEST_FILTER_TAG="-no_pip_gpu,${PIP_TEST_FILTER_TAG}"
fi
if [[ ${IS_MAC} == "1" ]]; then
  PIP_TEST_FILTER_TAG="-nomac,${PIP_TEST_FILTER_TAG}"
fi

# Bazel flags we need for all tests:
# define=no_tensorflow_py_deps=true, to skip all test dependencies.