Merge changes from github.
PiperOrigin-RevId: 189945839
This commit is contained in:
parent
cbede3ea75
commit
2d0531d72c
@ -22,6 +22,10 @@ organization for the purposes of conducting machine learning and deep neural
|
||||
networks research. The system is general enough to be applicable in a wide
|
||||
variety of other domains, as well.
|
||||
|
||||
Keep up to date with release announcements and security updates by
|
||||
subscribing to
|
||||
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
|
||||
|
||||
## Installation
|
||||
*See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.*
|
||||
|
||||
|
16
SECURITY.md
16
SECURITY.md
@ -6,7 +6,7 @@ report vulnerabilities in TensorFlow.
|
||||
|
||||
## TensorFlow models are programs
|
||||
|
||||
TensorFlow's runtime system interprets and executes programs. What machine
|
||||
TensorFlow's runtime system interprets and executes programs. What machine
|
||||
learning practitioners term
|
||||
[**models**](https://developers.google.com/machine-learning/glossary/#model) are
|
||||
expressed as programs that TensorFlow executes. TensorFlow programs are encoded
|
||||
@ -28,12 +28,12 @@ data you supply to TensorFlow to train a model, or to use a model to run
|
||||
inference on the data.
|
||||
|
||||
**TensorFlow models are programs, and need to be treated as such from a security
|
||||
perspective.**
|
||||
perspective.**
|
||||
|
||||
## Running untrusted models
|
||||
|
||||
As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
|
||||
[nsjail](https://github.com/google/nsjail)).
|
||||
[nsjail](https://github.com/google/nsjail)).
|
||||
|
||||
There are several ways in which a model could become untrusted. Obviously, if an
|
||||
untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
|
||||
@ -109,11 +109,11 @@ graphs known to the `ModelServer`. This means that an attacker may run
|
||||
graphs using untrusted inputs as described above, but they would not be able to
|
||||
execute arbitrary graphs. It is possible to safely expose a `ModelServer`
|
||||
directly to an untrusted network, **but only if the graphs it is configured to
|
||||
use have been carefully audited to be safe**.
|
||||
use have been carefully audited to be safe**.
|
||||
|
||||
Similar to best practices for other servers, we recommend running any
|
||||
`ModelServer` with appropriate privileges (i.e., using a separate user with
|
||||
reduced permisisons). In the spirit of defense in depth, we recommend
|
||||
reduced permissions). In the spirit of defense in depth, we recommend
|
||||
authenticating requests to any TensorFlow server connected to an untrusted
|
||||
network, as well as sandboxing the server to minimize the adverse effects of
|
||||
any breach.
|
||||
@ -129,11 +129,11 @@ with specially crafted inputs.
|
||||
### What is a vulnerability?
|
||||
|
||||
Given TensorFlow's flexibility, it is possible to specify computation graphs
|
||||
which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models
|
||||
which exhibit unexpected or unwanted behavior. The fact that TensorFlow models
|
||||
can perform arbitrary computations means that they may read and write files,
|
||||
communicate via the network, produce deadlocks and infinite loops, or run out
|
||||
of memory. It is only when these behaviors are outside the specifications of the
|
||||
operations involved that such behavior is a vulnerability.
|
||||
operations involved that such behavior is a vulnerability.
|
||||
|
||||
A `FileWriter` writing a file is not unexpected behavior and therefore is not a
|
||||
vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
|
||||
@ -168,7 +168,7 @@ below).
|
||||
|
||||
Please use a descriptive subject line for your report email. After the initial
|
||||
reply to your report, the security team will endeavor to keep you informed of
|
||||
the progress being made towards a fix and announcement.
|
||||
the progress being made towards a fix and announcement.
|
||||
|
||||
If you believe that an existing (public) issue is security-related, please send
|
||||
an email to `security@tensorflow.org`. The email should include the issue ID and
|
||||
|
@ -1048,7 +1048,10 @@ def set_tf_tensorrt_install_path(environ_cp):
|
||||
|
||||
for lib_file in possible_files:
|
||||
if is_compatible(lib_file, cuda_ver, cudnn_ver):
|
||||
ver_str = nvinfer_pattern.search(lib_file).group(1)
|
||||
matches = nvinfer_pattern.search(lib_file)
|
||||
if len(matches.groups()) == 0:
|
||||
continue
|
||||
ver_str = matches.group(1)
|
||||
ver = convert_version_to_int(ver_str) if len(ver_str) else 0
|
||||
if ver > highest_ver[0]:
|
||||
highest_ver = [ver, ver_str, lib_file]
|
||||
|
@ -38,14 +38,7 @@ namespace xla {
|
||||
|
||||
GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
|
||||
size_t pointer_size)
|
||||
: platform_id_(platform_id), pointer_size_(pointer_size) {
|
||||
// We currently only support kHostPlatformId for CPU, kCudaPlatformId for
|
||||
// GPU and kInterpreterPlatformId for Interpreter. Before supporting other
|
||||
// platforms, we need to test this transfer manager on them.
|
||||
CHECK(platform_id_ == se::host::kHostPlatformId ||
|
||||
platform_id_ == se::interpreter::kInterpreterPlatformId ||
|
||||
platform_id_ == se::cuda::kCudaPlatformId);
|
||||
}
|
||||
: platform_id_(platform_id), pointer_size_(pointer_size) {}
|
||||
|
||||
se::Platform::Id GenericTransferManager::PlatformId() const {
|
||||
return platform_id_;
|
||||
|
@ -723,7 +723,7 @@ INSTANTIATE_TEST_CASE_P(
|
||||
);
|
||||
#endif
|
||||
|
||||
TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
|
||||
XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
|
||||
|
@ -121,6 +121,7 @@ cc_library(
|
||||
"//tensorflow/contrib/coder:all_kernels",
|
||||
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels",
|
||||
"//tensorflow/contrib/data/kernels:dataset_kernels",
|
||||
"//tensorflow/contrib/kafka:dataset_kernels",
|
||||
"//tensorflow/contrib/factorization/kernels:all_kernels",
|
||||
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
|
||||
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
|
||||
@ -147,7 +148,7 @@ cc_library(
|
||||
"//tensorflow/contrib/factorization:all_ops",
|
||||
"//tensorflow/contrib/framework:all_ops",
|
||||
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
|
||||
"//tensorflow/contrib/kafka:kafka_ops_op_lib",
|
||||
"//tensorflow/contrib/kafka:dataset_ops_op_lib",
|
||||
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
|
||||
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
|
||||
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib",
|
||||
|
@ -26,7 +26,7 @@ The CMake files in this directory can build the core TensorFlow runtime, an
|
||||
example C++ binary, and a PIP package containing the runtime and Python
|
||||
bindings.
|
||||
|
||||
### Pre-requisites
|
||||
### Prerequisites
|
||||
|
||||
* CMake version 3.5 or later.
|
||||
|
||||
@ -34,14 +34,16 @@ bindings.
|
||||
|
||||
* [SWIG](http://www.swig.org/download.html)
|
||||
|
||||
* Additional pre-requisites for Microsoft Windows:
|
||||
* Additional prerequisites for Microsoft Windows:
|
||||
- Visual Studio 2015
|
||||
- Python 3.5
|
||||
- NumPy 1.11.0 or later
|
||||
|
||||
* Additional pre-requisites for Linux:
|
||||
* Additional prerequisites for Linux:
|
||||
- Python 2.7 or later
|
||||
- [Docker](https://www.docker.com/) (for automated testing)
|
||||
|
||||
* Python dependencies:
|
||||
- wheel
|
||||
- NumPy 1.11.0 or later
|
||||
|
||||
### Known-good configurations
|
||||
@ -102,7 +104,7 @@ ops or APIs.
|
||||
Step-by-step Windows build
|
||||
==========================
|
||||
|
||||
1. Install the pre-requisites detailed above, and set up your environment.
|
||||
1. Install the prerequisites detailed above, and set up your environment.
|
||||
|
||||
* The following commands assume that you are using the Windows Command
|
||||
Prompt (`cmd.exe`). You will need to set up your environment to use the
|
||||
|
1
tensorflow/contrib/cmake/external/grpc.cmake
vendored
1
tensorflow/contrib/cmake/external/grpc.cmake
vendored
@ -35,6 +35,7 @@ else()
|
||||
set(grpc_STATIC_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a
|
||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a
|
||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a
|
||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a)
|
||||
endif()
|
||||
|
||||
|
@ -16,7 +16,7 @@ include (ExternalProject)
|
||||
|
||||
set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src)
|
||||
set(PROTOBUF_URL https://github.com/google/protobuf.git)
|
||||
set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a)
|
||||
set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9)
|
||||
|
||||
if(WIN32)
|
||||
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
|
||||
|
@ -478,6 +478,10 @@ if (tensorflow_BUILD_CC_TESTS)
|
||||
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/advisor/*_test.cc"
|
||||
)
|
||||
|
||||
list(REMOVE_ITEM tf_test_src_simple
|
||||
${tf_core_profiler_test_srcs}
|
||||
)
|
||||
|
||||
set(tf_test_lib tf_test_lib)
|
||||
add_library(${tf_test_lib} STATIC ${tf_src_testlib})
|
||||
|
||||
|
@ -40,6 +40,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
|
||||
@@rejection_resample
|
||||
@@scan
|
||||
@@shuffle_and_repeat
|
||||
@@sliding_window_batch
|
||||
@@sloppy_interleave
|
||||
@@unbatch
|
||||
|
||||
@ -72,6 +73,9 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset
|
||||
from tensorflow.contrib.data.python.ops.resampling import rejection_resample
|
||||
from tensorflow.contrib.data.python.ops.scan_ops import scan
|
||||
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
|
||||
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
|
||||
from tensorflow.python.data.ops.iterator_ops import Iterator
|
||||
from tensorflow.python.ops.parsing_ops import parse_single_example_v2 as parse_single_example
|
||||
# pylint: enable=unused-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
@ -498,6 +498,23 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "slide_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["slide_dataset_op_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -45,12 +45,10 @@ class ResampleTest(test.TestCase):
|
||||
target_dist=target_dist,
|
||||
initial_dist=initial_dist,
|
||||
class_func=lambda c, _: c,
|
||||
seed=27)).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
seed=27)).make_one_shot_iterator())
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
returned = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
|
@ -0,0 +1,242 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import sliding
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SlideDatasetTest(test.TestCase):
  """Tests for the `sliding.sliding_window_batch` dataset transformation."""

  def testSlideDataset(self):
    """Slides a window over a mapped and repeated dense dataset.

    Covers window sizes that do and do not divide the input evenly, inputs
    shorter than one window, empty inputs, and invalid `window_size` /
    `stride` combinations.
    """
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    count = array_ops.placeholder(dtypes.int64, shape=[])
    window_size = array_ops.placeholder(dtypes.int64, shape=[])
    stride = array_ops.placeholder(dtypes.int64, shape=[])

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(count) -> _SlideDataset(window_size, stride).
    iterator = (dataset_ops.Dataset.from_tensor_slices(components)
                .map(_map_fn)
                .repeat(count)
                .apply(sliding.sliding_window_batch(window_size, stride))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    # The leading (window) dimension is expected to be statically unknown,
    # since window_size is only fed at run time.
    self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                     [t.shape.as_list() for t in get_next])

    with self.test_session() as sess:
      # Slide over a finite input, where the window_size divides the
      # total number of elements.
      sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7})
      # Number of full windows: (num_elements - window_size) // stride + 1,
      # i.e. the same formula as the output length of an unpadded convolution.
      num_batches = (20 * 7 - 14) // 7 + 1
      for i in range(num_batches):
        result = sess.run(get_next)
        for component, result_component in zip(components, result):
          for j in range(14):
            self.assertAllEqual(component[(i*7 + j) % 7]**2,
                                result_component[j])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Slide over a finite input, where the window_size does not
      # divide the total number of elements.
      sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9})

      num_batches = (20 * 7 - 17) // 9 + 1
      for i in range(num_batches):
        result = sess.run(get_next)
        for component, result_component in zip(components, result):
          for j in range(17):
            self.assertAllEqual(component[(i*9 + j) % 7]**2,
                                result_component[j])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Slide over a finite input, which is less than window_size,
      # should fail straight away.
      sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Slide over an empty input should fail straight away.
      sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Empty window_size should be an initialization time error.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0})

      # Invalid stride should be an initialization time error: the cases
      # below use a zero stride, a stride equal to window_size, and a stride
      # larger than window_size.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0})
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3})
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5})

  def assertSparseValuesEqual(self, a, b):
    """Asserts that two sparse values agree on indices, values and shape."""
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def testSlideSparse(self):
    """Slides a window over sparse elements that share one dense shape."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    iterator = dataset_ops.Dataset.range(10).map(_sparse).apply(
        sliding.sliding_window_batch(5, 3)).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      num_batches = (10 - 5) // 3 + 1
      for i in range(num_batches):
        actual = sess.run(get_next)
        expected = sparse_tensor.SparseTensorValue(
            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
            values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
            dense_shape=[5, 1])
        self.assertTrue(sparse_tensor.is_sparse(actual))
        self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSlideSparseWithDifferentDenseShapes(self):
    """Slides a window over sparse elements whose dense shapes differ."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=array_ops.expand_dims(
              math_ops.range(i, dtype=dtypes.int64), 1),
          values=array_ops.fill([math_ops.to_int32(i)], i),
          dense_shape=[i])

    iterator = dataset_ops.Dataset.range(10).map(_sparse).apply(
        sliding.sliding_window_batch(5, 3)).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      num_batches = (10 - 5) // 3 + 1
      for i in range(num_batches):
        actual = sess.run(get_next)
        expected_indices = []
        expected_values = []
        # Row j of window i holds element (i * 3 + j), which contributes
        # (i * 3 + j) entries, each equal to its own element index.
        for j in range(5):
          for k in range(i * 3 + j):
            expected_indices.append([j, k])
            expected_values.append(i * 3 + j)
        expected = sparse_tensor.SparseTensorValue(
            indices=expected_indices,
            values=expected_values,
            dense_shape=[5, i * 3 + 5 - 1])
        self.assertTrue(sparse_tensor.is_sparse(actual))
        self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testNestedSlideSparse(self):
    """Applies two sliding windows in sequence to sparse elements."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    iterator = (dataset_ops.Dataset.range(10)
                .map(_sparse)
                .apply(sliding.sliding_window_batch(4, 2))
                .apply(sliding.sliding_window_batch(3, 1))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      # Slide: 1st batch.
      actual = sess.run(get_next)
      expected = sparse_tensor.SparseTensorValue(
          indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
                   [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0],
                   [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]],
          values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
          dense_shape=[3, 4, 1])
      self.assertTrue(sparse_tensor.is_sparse(actual))
      self.assertSparseValuesEqual(actual, expected)
      # Slide: 2nd batch.
      actual = sess.run(get_next)
      expected = sparse_tensor.SparseTensorValue(
          indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
                   [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0],
                   [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]],
          values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
          dense_shape=[3, 4, 1])
      self.assertTrue(sparse_tensor.is_sparse(actual))
      self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSlideShapeError(self):
    """Expects an error when elements in one window have different shapes."""

    def generator():
      yield [1.0, 2.0, 3.0]
      yield [4.0, 5.0, 6.0]
      yield [7.0, 8.0, 9.0, 10.0]

    iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32,
                                                   output_shapes=[None])
                .apply(sliding.sliding_window_batch(3, 1))
                .make_initializable_iterator())
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(iterator.initializer)
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r"Cannot batch tensors with different shapes in component 0. "
          r"First element had shape \[3\] and element 2 had shape \[4\]."):
        sess.run(next_element)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  test.main()
|
@ -106,6 +106,7 @@ py_library(
|
||||
"interleave_ops.py",
|
||||
"resampling.py",
|
||||
"scan_ops.py",
|
||||
"sliding.py",
|
||||
"stats_ops.py",
|
||||
"threadpool.py",
|
||||
"unique.py",
|
||||
|
@ -54,7 +54,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
|
||||
def _apply_fn(dataset):
|
||||
"""Function from `Dataset` to `Dataset` that applies the transformation."""
|
||||
dist_estimation_batch_size = 32
|
||||
target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist")
|
||||
target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
|
||||
class_values_ds = dataset.map(class_func)
|
||||
if initial_dist is not None:
|
||||
initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
|
||||
@ -151,7 +151,7 @@ def _calculate_acceptance_probs(initial_probs, target_probs):
|
||||
```
|
||||
|
||||
|
||||
A solution for a_i in terms of the other variabes is the following:
|
||||
A solution for a_i in terms of the other variables is the following:
|
||||
```a_i = (t_i / p_i) / max_i[t_i / p_i]```
|
||||
"""
|
||||
# Add tiny to initial_probs to avoid divide by zero.
|
||||
|
102
tensorflow/contrib/data/python/ops/sliding.py
Normal file
102
tensorflow/contrib/data/python/ops/sliding.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Sliding dataset transformations."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import sparse
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
|
||||
|
||||
class _SlideDataset(dataset_ops.Dataset):
  """A `Dataset` that passes a sliding window over its input."""

  def __init__(self, input_dataset, window_size, stride=1):
    """See `sliding_window_batch` for details.

    Args:
      input_dataset: The dataset the window slides over.
      window_size: Number of elements per window; converted to a `tf.int64`
        tensor.
      stride: Step between consecutive windows; converted to a `tf.int64`
        tensor. Defaults to `1`.
    """
    super(_SlideDataset, self).__init__()
    self._input_dataset = input_dataset
    self._window_size = ops.convert_to_tensor(
        window_size, dtype=dtypes.int64, name="window_size")
    self._stride = ops.convert_to_tensor(
        stride, dtype=dtypes.int64, name="stride")

  def _as_variant_tensor(self):
    """Builds the `slide_dataset` op on top of the input dataset's variant."""
    return gen_dataset_ops.slide_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        window_size=self._window_size,
        stride=self._stride,
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    """Same classes as the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Input shapes, each prefixed with an unknown-size window dimension."""
    input_shapes = self._input_dataset.output_shapes
    # Read the input shapes once and reuse them for both the structure and
    # the flattened components.
    return nest.pack_sequence_as(input_shapes, [
        tensor_shape.vector(None).concatenate(s)
        for s in nest.flatten(input_shapes)
    ])

  @property
  def output_types(self):
    """Same types as the input dataset."""
    return self._input_dataset.output_types
|
||||
|
||||
|
||||
def sliding_window_batch(window_size, stride=1):
|
||||
"""A sliding window with size of `window_size` and step of `stride`.
|
||||
|
||||
This transformation passes a sliding window over this dataset. The
|
||||
window size is `window_size` and step size is `stride`. If the left
|
||||
elements cannot fill up the sliding window, this transformation will
|
||||
drop the final smaller element. For example:
|
||||
|
||||
```python
|
||||
# NOTE: The following examples use `{ ... }` to represent the
|
||||
# contents of a dataset.
|
||||
a = { [1], [2], [3], [4], [5], [6] }
|
||||
|
||||
a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) ==
|
||||
{
|
||||
[[1], [2], [3]],
|
||||
[[3], [4], [5]],
|
||||
}
|
||||
```
|
||||
|
||||
Args:
|
||||
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
|
||||
elements in the sliding window.
|
||||
stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
|
||||
steps moving the sliding window forward for one iteration. The default
|
||||
is `1`. It must be in `[1, window_size)`.
|
||||
|
||||
Returns:
|
||||
A `Dataset` transformation function, which can be passed to
|
||||
@{tf.data.Dataset.apply}.
|
||||
"""
|
||||
def _apply_fn(dataset):
|
||||
return _SlideDataset(dataset, window_size, stride)
|
||||
|
||||
return _apply_fn
|
@ -224,7 +224,10 @@ py_test(
|
||||
srcs = ["python/ops/kmeans_test.py"],
|
||||
shard_count = 4,
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["notsan"], # b/67512932
|
||||
tags = [
|
||||
"nomac", # b/73741358
|
||||
"notsan", # b/67512932
|
||||
],
|
||||
deps = [
|
||||
":factorization_py",
|
||||
":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
|
||||
|
@ -256,6 +256,9 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height,
|
||||
if (p != std::string::npos) {
|
||||
string rgb24 = line.substr(p + 9, line.find(" ", p + 9));
|
||||
rgb24 = rgb24.substr(0, rgb24.find(","));
|
||||
// Strip anything after " ", in case the format is
|
||||
// `640x360 [SAR 1:1 DAR 16:9]`
|
||||
rgb24 = rgb24.substr(0, rgb24.find(" "));
|
||||
string rgb24_width = rgb24.substr(0, rgb24.find("x"));
|
||||
string rgb24_height = rgb24.substr(rgb24_width.length() + 1);
|
||||
if (strings::safe_strtou32(rgb24_width, &width_value) &&
|
||||
@ -270,8 +273,10 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height,
|
||||
// We only look for the first stream mapping to have the number of the
|
||||
// frames.
|
||||
// Once processed we will not further process stream mapping section.
|
||||
if (line.find("frame= ") == 0) {
|
||||
string number = line.substr(8, line.find(" ", 8));
|
||||
if (line.find("frame=") == 0) {
|
||||
// The format might be `frame= 166 ` or `frame=12488 `
|
||||
string number = line.substr(6);
|
||||
number = number.substr(number.find_first_not_of(" "));
|
||||
number = number.substr(0, number.find(" "));
|
||||
if (strings::safe_strtou32(number, &frames_value)) {
|
||||
in_mapping = false;
|
||||
|
@ -142,7 +142,7 @@ def arg_scope(list_ops_or_scope, **kwargs):
|
||||
else:
|
||||
# Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
|
||||
if not isinstance(list_ops_or_scope, (list, tuple)):
|
||||
raise TypeError('list_ops_or_scope must either be a list/tuple or reused'
|
||||
raise TypeError('list_ops_or_scope must either be a list/tuple or reused '
|
||||
'scope (i.e. dict)')
|
||||
try:
|
||||
current_scope = current_arg_scope().copy()
|
||||
|
@ -321,7 +321,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
||||
|
||||
NOTE: This function consumes images, computes their logits, and then
|
||||
computes the classifier score. If you would like to precompute many logits for
|
||||
large batches, use clasifier_score_from_logits(), which this method also
|
||||
large batches, use classifier_score_from_logits(), which this method also
|
||||
uses.
|
||||
|
||||
Args:
|
||||
@ -454,7 +454,7 @@ def frechet_classifier_distance(real_images,
|
||||
|
||||
This technique is described in detail in https://arxiv.org/abs/1706.08500.
|
||||
Given two Gaussian distributions with means m and m_w and covariance matrices
|
||||
C and C_w, this function calcuates
|
||||
C and C_w, this function calculates
|
||||
|
||||
|m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
|
||||
|
||||
@ -467,7 +467,7 @@ def frechet_classifier_distance(real_images,
|
||||
Frechet distance is biased. It is more biased for small sample sizes. (e.g.
|
||||
even if the two distributions are the same, for a small sample size, the
|
||||
expected Frechet distance is large). It is important to use the same
|
||||
sample size to compute frechet classifier distance when comparing two
|
||||
sample size to compute Frechet classifier distance when comparing two
|
||||
generative models.
|
||||
|
||||
NOTE: This function consumes images, computes their activations, and then
|
||||
@ -659,7 +659,7 @@ def frechet_classifier_distance_from_activations(real_activations,
|
||||
|
||||
This technique is described in detail in https://arxiv.org/abs/1706.08500.
|
||||
Given two Gaussian distributions with means m and m_w and covariance matrices
|
||||
C and C_w, this function calcuates
|
||||
C and C_w, this function calculates
|
||||
|
||||
|m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
|
||||
|
||||
|
@ -212,7 +212,7 @@ def sliced_wasserstein_distance(real_images,
|
||||
Args:
|
||||
real_images: (tensor) Real images (batch, height, width, channels).
|
||||
fake_images: (tensor) Fake images (batch, height, width, channels).
|
||||
resolution_min: (int) Minimum resolution for the Laplacion pyramid.
|
||||
resolution_min: (int) Minimum resolution for the Laplacian pyramid.
|
||||
patches_per_image: (int) Number of patches to extract per image per
|
||||
Laplacian level.
|
||||
patch_size: (int) Width of a square patch.
|
||||
@ -221,7 +221,7 @@ def sliced_wasserstein_distance(real_images,
|
||||
use_svd: experimental method to compute a more accurate distance.
|
||||
Returns:
|
||||
List of tuples (distance_real, distance_fake) for each level of the
|
||||
Laplacian pyramid from the highest resoluion to the lowest.
|
||||
Laplacian pyramid from the highest resolution to the lowest.
|
||||
distance_real is the Wasserstein distance between real images
|
||||
distance_fake is the Wasserstein distance between real and fake images.
|
||||
Raises:
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Miscellanous utilities for TFGAN code and examples.
|
||||
"""Miscellaneous utilities for TFGAN code and examples.
|
||||
|
||||
Includes:
|
||||
1) Conditioning the value of a Tensor, based on techniques from
|
||||
|
@ -17,7 +17,7 @@
|
||||
We use this to keep a history of values created by a generator, such that
|
||||
a discriminator can randomly be trained on some older samples, not just the
|
||||
current one. This can help to not let the discriminator get too far ahead of the
|
||||
generator and also to keep the system from oscilating, if the discriminator
|
||||
generator and also to keep the system from oscillating, if the discriminator
|
||||
forgets too fast what past samples from the generator looked like.
|
||||
|
||||
See the following papers for more details.
|
||||
@ -97,7 +97,7 @@ def tensor_pool(input_values,
|
||||
dtypes=[v.dtype for v in input_values],
|
||||
shapes=None)
|
||||
|
||||
# In pseudeo code this code does the following:
|
||||
# In pseudo code this code does the following:
|
||||
# if not pool_full:
|
||||
# enqueue(input_values)
|
||||
# return input_values
|
||||
|
@ -148,7 +148,7 @@ class VirtualBatchnormTest(test.TestCase):
|
||||
self.assertAllClose(bn_np[i, ...], vb_np)
|
||||
|
||||
def test_minibatch_independent(self):
|
||||
"""Test that virtual batch normalized exampels are independent.
|
||||
"""Test that virtual batch normalized examples are independent.
|
||||
|
||||
Unlike batch normalization, virtual batch normalization has the property
|
||||
that the virtual batch normalized value of an example is independent of the
|
||||
|
@ -110,7 +110,7 @@ class GridRNNCell(rnn.RNNCell):
|
||||
logging.warning('%s: Using a concatenated state is slower and will '
|
||||
'soon be deprecated. Use state_is_tuple=True.', self)
|
||||
if not output_is_tuple:
|
||||
logging.warning('%s: Using a concatenated output is slower and will'
|
||||
logging.warning('%s: Using a concatenated output is slower and will '
|
||||
'soon be deprecated. Use output_is_tuple=True.', self)
|
||||
|
||||
if num_dims < 1:
|
||||
|
@ -259,6 +259,7 @@ cuda_py_test(
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
data = [":sparse_image_warp_test_data"],
|
||||
tags = ["no_pip"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
@ -1,66 +1,93 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
)
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
tf_kernel_library(
|
||||
name = "kafka_kernels",
|
||||
srcs = ["kernels/kafka_dataset_ops.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/kernels:bounds_check_lib",
|
||||
"//tensorflow/core/kernels:dataset",
|
||||
"//third_party/eigen3",
|
||||
"@kafka",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["kafka_ops"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_kafka_ops",
|
||||
out = "python/ops/gen_kafka_ops.py",
|
||||
require_shape_functions = True,
|
||||
deps = [":kafka_ops_op_lib"],
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_gen_op_wrapper_py",
|
||||
"tf_kernel_library",
|
||||
"tf_custom_op_library",
|
||||
"tf_custom_op_py_library",
|
||||
"tf_gen_op_libs",
|
||||
"tf_py_test",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "kafka",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dataset_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "_dataset_ops.so",
|
||||
srcs = ["ops/dataset_ops.cc"],
|
||||
deps = [":dataset_kernels"],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["dataset_ops"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dataset_kernels",
|
||||
srcs = ["kernels/kafka_dataset_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
"@kafka",
|
||||
"@protobuf_archive//:protobuf_headers",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset_ops",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/ops/kafka_dataset_ops.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gen_kafka_ops",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training",
|
||||
":kafka_op_loader",
|
||||
"//tensorflow/python:dataset_ops_gen",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_dataset_ops",
|
||||
out = "python/ops/gen_dataset_ops.py",
|
||||
deps = ["//tensorflow/contrib/kafka:dataset_ops_op_lib"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "dataset_ops_kernels",
|
||||
deps = [
|
||||
":dataset_kernels",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "kafka_op_loader",
|
||||
srcs = ["python/ops/kafka_op_loader.py"],
|
||||
dso = ["//tensorflow/contrib/kafka:_dataset_ops.so"],
|
||||
kernels = [
|
||||
":dataset_ops_kernels",
|
||||
"//tensorflow/contrib/kafka:dataset_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":gen_dataset_ops",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
@ -88,6 +115,7 @@ tf_py_test(
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
"no_windows",
|
||||
"notap",
|
||||
],
|
||||
)
|
||||
@ -95,7 +123,9 @@ tf_py_test(
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
include = [
|
||||
"**/*",
|
||||
],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
|
@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/dataset.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
||||
#include "src-cpp/rdkafkacpp.h"
|
||||
|
||||
|
44
tensorflow/contrib/kafka/ops/dataset_ops.cc
Normal file
44
tensorflow/contrib/kafka/ops/dataset_ops.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("KafkaDataset")
|
||||
.Input("topics: string")
|
||||
.Input("servers: string")
|
||||
.Input("group: string")
|
||||
.Input("eof: bool")
|
||||
.Input("timeout: int64")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Doc(R"doc(
|
||||
Creates a dataset that emits the messages of one or more Kafka topics.
|
||||
|
||||
topics: A `tf.string` tensor containing one or more subscriptions,
|
||||
in the format of [topic:partition:offset:length],
|
||||
by default length is -1 for unlimited.
|
||||
servers: A list of bootstrap servers.
|
||||
group: The consumer group id.
|
||||
eof: If True, the kafka reader will stop on EOF.
|
||||
timeout: The timeout value for the Kafka Consumer to wait
|
||||
(in millisecond).
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
@ -17,8 +17,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.kafka.python.ops import gen_kafka_ops
|
||||
from tensorflow.python.data.ops.readers import Dataset
|
||||
from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import
|
||||
from tensorflow.contrib.kafka.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.data.ops.dataset_ops import Dataset
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -58,8 +59,8 @@ class KafkaDataset(Dataset):
|
||||
timeout, dtype=dtypes.int64, name="timeout")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group,
|
||||
self._eof, self._timeout)
|
||||
return gen_dataset_ops.kafka_dataset(self._topics, self._servers,
|
||||
self._group, self._eof, self._timeout)
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
|
24
tensorflow/contrib/kafka/python/ops/kafka_op_loader.py
Normal file
24
tensorflow/contrib/kafka/python/ops/kafka_op_loader.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python helper for loading kafka ops and kernels."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
_dataset_ops = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
|
@ -153,7 +153,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
raise ValueError("Unsupported momentum type {}. Must be one of {}."
|
||||
.format(momentum_type, legal_momentum_types))
|
||||
if momentum_type != "regular" and norm_constraint is not None:
|
||||
raise ValueError("Update clipping is only supported with momentum"
|
||||
raise ValueError("Update clipping is only supported with momentum "
|
||||
"type 'regular'.")
|
||||
if momentum_type not in ["regular", "adam"] and momentum != 0:
|
||||
raise ValueError("Momentum must be unspecified if using a momentum_type "
|
||||
|
@ -470,7 +470,7 @@ def embedding_lookup_unique(params, ids, name=None):
|
||||
ids = ops.convert_to_tensor(ids)
|
||||
shape = array_ops.shape(ids)
|
||||
ids_flat = array_ops.reshape(
|
||||
ids, math_ops.reduce_prod(shape, keep_dims=True))
|
||||
ids, math_ops.reduce_prod(shape, keepdims=True))
|
||||
unique_ids, idx = array_ops.unique(ids_flat)
|
||||
unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids)
|
||||
embeds_flat = array_ops.gather(unique_embeddings, idx)
|
||||
|
@ -125,7 +125,7 @@ def embed_sequence(ids,
|
||||
`reuse` is `None` or `False`.
|
||||
"""
|
||||
if not (reuse or (vocab_size and embed_dim)):
|
||||
raise ValueError('Must specify vocab size and embedding dimension when not'
|
||||
raise ValueError('Must specify vocab size and embedding dimension when not '
|
||||
'reusing. Got vocab_size=%s and embed_dim=%s' % (
|
||||
vocab_size, embed_dim))
|
||||
with variable_scope.variable_scope(
|
||||
|
@ -5,6 +5,8 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
package(default_visibility = [
|
||||
"//engedu/ml/tf_from_scratch:__pkg__",
|
||||
"//tensorflow:internal",
|
||||
@ -426,7 +428,10 @@ py_test(
|
||||
size = "medium",
|
||||
srcs = ["python/learn/estimators/kmeans_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["noasan"],
|
||||
tags = [
|
||||
"noasan", # b/73741358
|
||||
"nomac",
|
||||
],
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -917,8 +917,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
|
||||
if feed_fn:
|
||||
hooks.append(basic_session_run_hooks.FeedFnHook(feed_fn))
|
||||
if steps == 0:
|
||||
logging.warning('evaluation steps are 0. If `input_fn` does not raise'
|
||||
'OutOfRangeError`, the evaluation will never stop.'
|
||||
logging.warning('evaluation steps are 0. If `input_fn` does not raise '
|
||||
'`OutOfRangeError`, the evaluation will never stop. '
|
||||
'Use steps=None if intended.')
|
||||
if steps:
|
||||
hooks.append(
|
||||
|
@ -358,7 +358,7 @@ class Experiment(object):
|
||||
self._start_server()
|
||||
elif config.cluster_spec and config.master:
|
||||
raise ValueError(
|
||||
"For distributed runtime, Experiment class only works with"
|
||||
"For distributed runtime, Experiment class only works with "
|
||||
"tf.contrib.learn.RunConfig for now, but provided {}".format(
|
||||
type(config)))
|
||||
|
||||
|
@ -61,7 +61,7 @@ def embedding_lookup(params, ids, name='embedding_lookup'):
|
||||
ids = ops.convert_to_tensor(ids)
|
||||
shape = array_ops_.shape(ids)
|
||||
ids_flat = array_ops_.reshape(
|
||||
ids, math_ops.reduce_prod(shape, keep_dims=True))
|
||||
ids, math_ops.reduce_prod(shape, keepdims=True))
|
||||
embeds_flat = nn.embedding_lookup(params, ids_flat, name)
|
||||
embed_shape = array_ops_.concat([shape, [-1]], 0)
|
||||
embeds = array_ops_.reshape(embeds_flat, embed_shape)
|
||||
|
@ -27,10 +27,10 @@ LIBDIR := $(MAKEFILE_DIR)/gen/lib/
|
||||
GENDIR := $(MAKEFILE_DIR)/gen/obj/
|
||||
|
||||
# Settings for the host compiler.
|
||||
CXX := $(CC_PREFIX) gcc
|
||||
CXX := $(CC_PREFIX)gcc
|
||||
CXXFLAGS := --std=c++11 -O3 -DNDEBUG
|
||||
CC := $(CC_PREFIX) gcc
|
||||
CFLAGS :=
|
||||
CC := $(CC_PREFIX)gcc
|
||||
CFLAGS := -O3 -DNDEBUG
|
||||
LDOPTS :=
|
||||
LDOPTS += -L/usr/local/lib
|
||||
ARFLAGS := -r
|
||||
@ -57,10 +57,11 @@ LIBS := \
|
||||
|
||||
# If we're on Linux, also link in the dl library.
|
||||
ifeq ($(HOST_OS),LINUX)
|
||||
LIBS += -ldl -lpthread
|
||||
LIBS += -ldl
|
||||
endif
|
||||
|
||||
include $(MAKEFILE_DIR)/ios_makefile.inc
|
||||
include $(MAKEFILE_DIR)/rpi_makefile.inc
|
||||
|
||||
# This library is the main target for this makefile. It will contain a minimal
|
||||
# runtime that can be linked in to other programs.
|
||||
|
@ -99,7 +99,7 @@ Similar to the Android demo app, there's an iOS camera app that uses exactly the
|
||||
|
||||
This demo app requires a camera so it doesn't work with simulators. It needs to be executed on a real iOS device. Follow the instructions to build and run the demo app:
|
||||
|
||||
1. Run `third_party/tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app.
|
||||
1. Run `tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app.
|
||||
1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`.
|
||||
1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file.
|
||||
1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode.
|
||||
@ -165,7 +165,7 @@ bazel-bin/tensorflow/python/tools/freeze_graph\
|
||||
--input_graph=/tmp/mobilenet_v1_224.pb \
|
||||
--input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
|
||||
--input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
|
||||
--output_node_names=MobileNet/Predictions/Reshape_1
|
||||
--output_node_names=MobilenetV1/Predictions/Reshape_1
|
||||
```
|
||||
|
||||
The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph. The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with
|
||||
|
@ -33,7 +33,7 @@ class AllocationInfo;
|
||||
// each tensor needs to be allocated and deallocated, and preallocates all the
|
||||
// necessary memory (the PlanAllocations phase). It then assigns portions of
|
||||
// this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may
|
||||
// share some of the bufer if a tensor B is to be allocated after another tensor
|
||||
// share some of the buffer if a tensor B is to be allocated after another tensor
|
||||
// A has been deallocated.
|
||||
//
|
||||
// If dynamic tensors are used the planning steps can be repeated during model
|
||||
|
22
tensorflow/contrib/lite/build_rpi_lib.sh
Executable file
22
tensorflow/contrib/lite/build_rpi_lib.sh
Executable file
@ -0,0 +1,22 @@
|
||||
#!/bin/bash -x
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR/../../.."
|
||||
|
||||
CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/Makefile TARGET=RPI TARGET_ARCH=armv7
|
@ -30,7 +30,7 @@ namespace tflite {
|
||||
// va_list args;
|
||||
// foo.Report("test %d", args); // where args is va_list
|
||||
//
|
||||
// Sublclass ErrorReporter to provide another reporting destination.
|
||||
// Subclass ErrorReporter to provide another reporting destination.
|
||||
// For example, if you have a GUI program, you might redirect to a buffer
|
||||
// that drives a GUI error log box.
|
||||
class ErrorReporter {
|
||||
|
@ -22,6 +22,15 @@ Then install
|
||||
brew install automake
|
||||
brew install libtool
|
||||
```
|
||||
If you get an error where either automake or libtool install but do not link correctly, you'll first need to:
|
||||
```bash
|
||||
sudo chown -R $(whoami) /usr/local/*
|
||||
```
|
||||
Then follow the instructions to perform the linking:
|
||||
```bash
|
||||
brew link automake
|
||||
brew link libtool
|
||||
```
|
||||
|
||||
Then you need to run a shell script to download the dependencies you need:
|
||||
|
||||
|
50
tensorflow/contrib/lite/g3doc/rpi.md
Normal file
50
tensorflow/contrib/lite/g3doc/rpi.md
Normal file
@ -0,0 +1,50 @@
|
||||
# TensorFlow Lite for Raspberry Pi
|
||||
|
||||
## Cross compiling
|
||||
### Installing toolchain
|
||||
This has been tested on Ubuntu 16.04.3 64bit and Tensorflow devel docker image [tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
|
||||
|
||||
To cross compile TensorFlow Lite, first install the toolchain and libs.
|
||||
```bash
|
||||
sudo apt-get update
|
||||
sudo apt-get install crossbuild-essential-armhf
|
||||
```
|
||||
> If you are using Docker, you may not need `sudo`.
|
||||
|
||||
### Building
|
||||
Clone this TensorFlow repository, then run this script at the root of the repository to download all the dependencies:
|
||||
> The TensorFlow repository is located at `/tensorflow` if you are using the `tensorflow/tensorflow:nightly-devel` Docker image, so you can try it there directly.
|
||||
```bash
|
||||
./tensorflow/contrib/lite/download_dependencies.sh
|
||||
```
|
||||
Note that you only need to do this once.
|
||||
|
||||
You should then be able to compile:
|
||||
```bash
|
||||
./tensorflow/contrib/lite/build_rpi_lib.sh
|
||||
```
|
||||
|
||||
This should compile a static library in:
|
||||
`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`.
|
||||
|
||||
## Native compiling
|
||||
This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1).
|
||||
|
||||
Log in to your Raspberry Pi and install the toolchain.
|
||||
```bash
|
||||
sudo apt-get install build-essential
|
||||
```
|
||||
|
||||
First, clone this TensorFlow repository. Run this at the root of the repository:
|
||||
```bash
|
||||
./tensorflow/contrib/lite/download_dependencies.sh
|
||||
```
|
||||
Note that you only need to do this once.
|
||||
|
||||
You should then be able to compile:
|
||||
```bash
|
||||
./tensorflow/contrib/lite/build_rpi_lib.sh
|
||||
```
|
||||
|
||||
This should compile a static library in:
|
||||
`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`.
|
@ -493,7 +493,7 @@ class Interpreter {
|
||||
// During Invoke(), Interpreter will allocate input tensors first, which are
|
||||
// known to be fixed size. Then it will allocate outputs from nodes as many
|
||||
// as possible. When there is a node that produces dynamic sized tensor.
|
||||
// Intepreter will stop allocating tensors, set the value of next allocate
|
||||
// Interpreter will stop allocating tensors, set the value of next allocate
|
||||
// node id, and execute the node to generate the output tensor before continue
|
||||
// to allocate successors. This process repeats until all nodes are executed.
|
||||
// NOTE: this relies on the order of nodes that is in topological order.
|
||||
|
@ -42,7 +42,7 @@ TEST(BasicInterpreter, InvokeInvalidModel) {
|
||||
ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
|
||||
}
|
||||
|
||||
// Test size accesser functions.
|
||||
// Test size accessor functions.
|
||||
TEST(BasicInterpreter, TestSizeFunctions) {
|
||||
Interpreter interpreter;
|
||||
int base_index;
|
||||
|
@ -64,7 +64,7 @@ struct OpData {
|
||||
|
||||
TfLitePaddingValues padding;
|
||||
// The scaling factor from input to output (aka the 'real multiplier') can
|
||||
// be represented as a fixed point multipler plus a left shift.
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
|
@ -52,7 +52,7 @@ enum KernelType {
|
||||
struct OpData {
|
||||
TfLitePaddingValues padding;
|
||||
// The scaling factor from input to output (aka the 'real multiplier') can
|
||||
// be represented as a fixed point multipler plus a left shift.
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
|
@ -48,7 +48,7 @@ enum KernelType {
|
||||
|
||||
struct OpData {
|
||||
// The scaling factor from input to output (aka the 'real multiplier') can
|
||||
// be represented as a fixed point multipler plus a left shift.
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "third_party/fft2d/fft.h"
|
||||
|
@ -58,7 +58,7 @@ inline bool IsConstantTensor(TfLiteTensor* tensor) {
|
||||
}
|
||||
|
||||
// Determines whether tensor is dynamic. Note that a tensor can be non-const and
|
||||
// not dynamic. This function specificially checks for a dynamic tensor.
|
||||
// not dynamic. This function specifically checks for a dynamic tensor.
|
||||
inline bool IsDynamicTensor(TfLiteTensor* tensor) {
|
||||
return tensor->allocation_type == kTfLiteDynamic;
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// LSH Projection projects an input to a bit vector via locality senstive
|
||||
// LSH Projection projects an input to a bit vector via locality sensitive
|
||||
// hashing.
|
||||
//
|
||||
// Options:
|
||||
|
@ -213,9 +213,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
// present.
|
||||
// 2) If projection weight is present, then projection bias is optional.
|
||||
// TODO(ghodrat): make sure this is correct.
|
||||
const bool projecton_tensors_consistent =
|
||||
const bool projection_tensors_consistent =
|
||||
((projection_weights != nullptr) || (projection_bias == nullptr));
|
||||
TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
|
||||
TF_LITE_ENSURE(context, projection_tensors_consistent == true);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@ -357,7 +357,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existense of only one to the get the condition.
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights != nullptr);
|
||||
|
||||
|
@ -49,20 +49,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions);
|
||||
int num_output_elements = 1;
|
||||
int strech_dim = -1;
|
||||
int stretch_dim = -1;
|
||||
for (int i = 0; i < params->num_dimensions; ++i) {
|
||||
int value = params->shape[i];
|
||||
if (value == -1) {
|
||||
TF_LITE_ENSURE_EQ(context, strech_dim, -1);
|
||||
strech_dim = i;
|
||||
TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
|
||||
stretch_dim = i;
|
||||
} else {
|
||||
num_output_elements *= value;
|
||||
output_size->data[i] = value;
|
||||
}
|
||||
}
|
||||
if (strech_dim != -1) {
|
||||
output_size->data[strech_dim] = num_input_elements / num_output_elements;
|
||||
num_output_elements *= output_size->data[strech_dim];
|
||||
if (stretch_dim != -1) {
|
||||
output_size->data[stretch_dim] = num_input_elements / num_output_elements;
|
||||
num_output_elements *= output_size->data[stretch_dim];
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
|
||||
|
@ -60,7 +60,7 @@ TEST(ReshapeOpTest, TooManyDimensions) {
|
||||
|
||||
TEST(ReshapeOpTest, TooManySpecialDimensions) {
|
||||
EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}),
|
||||
"strech_dim != -1");
|
||||
"stretch_dim != -1");
|
||||
}
|
||||
|
||||
TEST(ReshapeOpTest, SimpleTest) {
|
||||
|
@ -141,8 +141,8 @@ void SingleOpModel::SetBuiltinOp(BuiltinOperator type,
|
||||
|
||||
void SingleOpModel::SetCustomOp(
|
||||
const string& name, const std::vector<uint8_t>& custom_option,
|
||||
const std::function<TfLiteRegistration*()>& registeration) {
|
||||
custom_registrations_[name] = registeration;
|
||||
const std::function<TfLiteRegistration*()>& registration) {
|
||||
custom_registrations_[name] = registration;
|
||||
opcodes_.push_back(
|
||||
CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data()));
|
||||
operators_.push_back(CreateOperator(
|
||||
|
@ -360,7 +360,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existense of only one to the get the condition.
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights != nullptr);
|
||||
|
||||
|
@ -34,8 +34,8 @@ class MemoryPlanner {
|
||||
// [first_node, last_node].
|
||||
virtual TfLiteStatus ExecuteAllocations(int first_node, int last_node) = 0;
|
||||
|
||||
// Invalidates allocations made earliers. This is called when tensors sizes
|
||||
// have change. All planned allocations remain, but can't be used until
|
||||
// Invalidates allocations made earlier. This is called when tensors sizes
|
||||
// have changed. All planned allocations remain, but can't be used until
|
||||
// ExecuteAllocations() is called.
|
||||
virtual TfLiteStatus ResetAllocations() = 0;
|
||||
};
|
||||
|
@ -81,7 +81,7 @@ class FlatBufferModel {
|
||||
const tflite::Model* model_spec,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
// Releases memory or unmaps mmaped meory.
|
||||
// Releases memory or unmaps mmaped memory.
|
||||
~FlatBufferModel();
|
||||
|
||||
// Copying or assignment is disallowed to simplify ownership semantics.
|
||||
|
@ -569,7 +569,7 @@ enum {
|
||||
ANEURALNETWORKS_LOGISTIC = 14,
|
||||
|
||||
/**
|
||||
* Projects an input to a bit vector via locality senstive hashing.
|
||||
* Projects an input to a bit vector via locality sensitive hashing.
|
||||
*
|
||||
* Inputs:
|
||||
* * 0: Hash functions. Dim.size == 2, DataType: Float.
|
||||
|
33
tensorflow/contrib/lite/rpi_makefile.inc
Normal file
33
tensorflow/contrib/lite/rpi_makefile.inc
Normal file
@ -0,0 +1,33 @@
|
||||
# Raspberry Pi build settings. Everything below only takes effect when the
# including Makefile sets TARGET to RPI.
ifeq ($(TARGET), RPI)
  # Extra compiler and linker flags used only for the armv7 architecture.
  ifeq ($(TARGET_ARCH), armv7)
    CXXFLAGS += -march=armv7-a -mfpu=neon-vfpv4 -funsafe-math-optimizations -ftree-vectorize

    CCFLAGS += -march=armv7-a -mfpu=neon-vfpv4 -funsafe-math-optimizations -ftree-vectorize

    # Note: ':=' replaces any LDFLAGS value set before this file is included.
    LDFLAGS := -Wl,--no-export-dynamic -Wl,--exclude-libs,ALL -Wl,--gc-sections -Wl,--as-needed
  endif

  # Libraries linked for every Raspberry Pi architecture.
  LIBS := -lstdc++ -lpthread -lm -ldl

  # Keep build products of each architecture in their own sub-directory.
  OBJDIR := $(OBJDIR)rpi_$(TARGET_ARCH)/
  LIBDIR := $(LIBDIR)rpi_$(TARGET_ARCH)/
  BINDIR := $(BINDIR)rpi_$(TARGET_ARCH)/
  DEPDIR := $(DEPDIR)rpi_$(TARGET_ARCH)/
endif
|
@ -39,8 +39,8 @@ import tensorflow as tf
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Script to move TFLite models from pre-release schema to"
|
||||
" new schema.")
|
||||
description="Script to move TFLite models from pre-release schema to "
|
||||
"new schema.")
|
||||
parser.add_argument(
|
||||
"input",
|
||||
type=str,
|
||||
@ -48,7 +48,7 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"output",
|
||||
type=str,
|
||||
help="Output json or bin TensorFlow lite model compliant with"
|
||||
help="Output json or bin TensorFlow lite model compliant with "
|
||||
"the new schema. Extension must be `.json`, `.bin` or `.tflite`.")
|
||||
|
||||
|
||||
@ -258,7 +258,7 @@ class Converter(object):
|
||||
# Check if builtin_code is the appropriate string type
|
||||
# use type("") instead of str or unicode. for py2and3
|
||||
if not isinstance(operator_code["builtin_code"], type(u"")):
|
||||
raise ValueError("builtin_code %r is non-string. this usually means"
|
||||
raise ValueError("builtin_code %r is non-string. this usually means "
|
||||
"your model has consistency problems." %
|
||||
(operator_code["builtin_code"]))
|
||||
operator_code["builtin_code"] = (RemapOperator(
|
||||
|
@ -113,21 +113,21 @@ TfLiteStatus SimpleMemoryArena::Commit(TfLiteContext* context) {
|
||||
underlying_buffer_size_ = required_size;
|
||||
underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr;
|
||||
}
|
||||
commited_ = true;
|
||||
committed_ = true;
|
||||
return underlying_buffer_ != nullptr ? kTfLiteOk : kTfLiteError;
|
||||
}
|
||||
|
||||
TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context,
|
||||
const ArenaAlloc& alloc,
|
||||
char** output_ptr) {
|
||||
TF_LITE_ENSURE(context, commited_);
|
||||
TF_LITE_ENSURE(context, committed_);
|
||||
TF_LITE_ENSURE(context, output_ptr != nullptr);
|
||||
*output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SimpleMemoryArena::Clear() {
|
||||
commited_ = false;
|
||||
committed_ = false;
|
||||
high_water_mark_ = 0;
|
||||
allocs_.clear();
|
||||
return kTfLiteOk;
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
|
||||
// This little structure holds the offset and the size for a dynamic memory
|
||||
// allocation in the memory arena. When the arena is commited and the
|
||||
// allocation in the memory arena. When the arena is committed and the
|
||||
// underlying buffer is set, the alloc can be resolved into an actual memory
|
||||
// pointer.
|
||||
struct ArenaAlloc {
|
||||
@ -43,7 +43,7 @@ struct ArenaAlloc {
|
||||
class SimpleMemoryArena {
|
||||
public:
|
||||
explicit SimpleMemoryArena(size_t arena_alignment)
|
||||
: commited_(false),
|
||||
: committed_(false),
|
||||
arena_alignment_(arena_alignment),
|
||||
high_water_mark_(0),
|
||||
underlying_buffer_size_(0),
|
||||
@ -73,7 +73,7 @@ class SimpleMemoryArena {
|
||||
}
|
||||
|
||||
private:
|
||||
bool commited_;
|
||||
bool committed_;
|
||||
size_t arena_alignment_;
|
||||
size_t high_water_mark_;
|
||||
std::unique_ptr<char[]> underlying_buffer_;
|
||||
|
@ -194,6 +194,8 @@ with:
|
||||
srcs = glob(["libs/arm64-v8a/*.so"]),
|
||||
```
|
||||
|
||||
If you are building for Android TV (Shield TV devices), replace `"portrait"` with `"landscape"` for `android:screenOrientation` in all four activities in `tensorflow/examples/android/AndroidManifest.xml`.
|
||||
|
||||
Then run:
|
||||
```bash
|
||||
# Create dir for native libs
|
||||
|
@ -80,10 +80,9 @@ if [[ ! -z "${OPTIMIZE_FOR_GRAPH}" ]]; then
|
||||
fi
|
||||
else
|
||||
echo "${PRNT_SLCTV_BIN} found. Using it"
|
||||
${PRNT_SLCTV_BIN} --graphs=${OPTIMIZE_FOR_GRAPH} > ${TOP_SRCDIR}/tensorflow/core/framework/ops_to_register.h
|
||||
|
||||
fi
|
||||
|
||||
${PRNT_SLCTV_BIN} --graphs=${OPTIMIZE_FOR_GRAPH} > ${TOP_SRCDIR}/tensorflow/core/framework/ops_to_register.h
|
||||
fi
|
||||
|
||||
if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then
|
||||
@ -111,7 +110,7 @@ if [[ -z "${BUILD_ARCH}" ]]; then
|
||||
TARGET_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh -t ios`
|
||||
else
|
||||
# arch specified so build just that
|
||||
TARGET_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh -t ios -a ${BUILD_ARCH}`
|
||||
TARGET_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh -t ios -a "${BUILD_ARCH}"`
|
||||
fi
|
||||
export HOST_NSYNC_LIB TARGET_NSYNC_LIB
|
||||
|
||||
|
@ -3647,7 +3647,7 @@ def cohen_kappa(labels,
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported'
|
||||
raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported '
|
||||
'when eager execution is enabled.')
|
||||
if num_classes < 2:
|
||||
raise ValueError('`num_classes` must be >= 2.'
|
||||
|
@ -214,7 +214,7 @@ def masked_convolution(inputs,
|
||||
elif data_format == 'NCHW':
|
||||
df = 'channels_first'
|
||||
else:
|
||||
raise ValueError('Unsupported data fromat', data_format)
|
||||
raise ValueError('Unsupported data format', data_format)
|
||||
|
||||
layer = layer_class(
|
||||
filters=num_outputs,
|
||||
|
@ -216,7 +216,7 @@ def _partitioned_variable_assign(partitioned_var, new_value):
|
||||
"""Assign op for partitioned variables.
|
||||
|
||||
Args:
|
||||
partitioned_var: A partitioned tensotflow variable
|
||||
partitioned_var: A partitioned tensorflow variable
|
||||
new_value: Value to be assigned to the variable var
|
||||
|
||||
Returns:
|
||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
||||
// Skip MPI C++ bindings support, this matches the usage in other places
|
||||
#define OMPI_SKIP_MPICXX
|
||||
#include "third_party/mpi/mpi.h"
|
||||
#define MPI_CHECK(cmd) \
|
||||
do { \
|
||||
|
@ -53,7 +53,7 @@ def from_contrib_estimator(estimator,
|
||||
`Estimator`.
|
||||
"""
|
||||
if isinstance(estimator, core_estimator.Estimator):
|
||||
raise TypeError('Espected estimator to be of type '
|
||||
raise TypeError('Expected estimator to be of type '
|
||||
'tf.contrib.learn.Estimator, but got type '
|
||||
'tf.python.estimator.Estimator. You likely want to call '
|
||||
'from_estimator.')
|
||||
@ -88,7 +88,7 @@ def from_estimator(estimator,
|
||||
`Estimator`.
|
||||
"""
|
||||
if isinstance(estimator, contrib_estimator.Estimator):
|
||||
raise TypeError('Espected estimator to be of type '
|
||||
raise TypeError('Expected estimator to be of type '
|
||||
'tf.python.estimator.Estimator, but got type '
|
||||
'tf.contrib.learn.Estimator. You likely want to call '
|
||||
'from_contrib_estimator.')
|
||||
|
@ -212,7 +212,7 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
|
||||
|
||||
def __init__(self):
|
||||
self.cant_return = False
|
||||
super(gast.NodeVisitor, self).__init__()
|
||||
super(DetectReturnInUnsupportedControlFlow, self).__init__()
|
||||
|
||||
def visit_While(self, node):
|
||||
self.cant_return = True
|
||||
|
@ -237,7 +237,7 @@ def _FindFusedBatchNorms(graph):
|
||||
# The batch variance used during forward and backward prop is biased,
|
||||
# i.e. it is calculated as: V=sum((x(k)-mu)^2)/N. For the moving average
|
||||
# calculation, the variance is corrected by the term N/(N-1) (Bessel's
|
||||
# correction). The variance tensor read from FuseBatchNorm has bessel's
|
||||
# correction). The variance tensor read from FusedBatchNorm has Bessel's
|
||||
# correction applied, so we undo it here.
|
||||
scope, sep, _ = bn_op.name.rpartition('/')
|
||||
g = ops.get_default_graph()
|
||||
@ -306,7 +306,7 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
|
||||
|
||||
Args:
|
||||
context: The scope under which we look for batch norm params
|
||||
match: Object containg required batch norm tensors for correction
|
||||
match: Object containing required batch norm tensors for correction
|
||||
computation.
|
||||
freeze_batch_norm_delay: Delay in steps at which computation switches
|
||||
from regular batch norm to frozen mean and variance.
|
||||
|
@ -282,8 +282,8 @@ def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits,
|
||||
Args:
|
||||
inputs: a tensor containing values to be quantized.
|
||||
min_var: a variable containing quantization range lower end(s).
|
||||
max_var: a variable containing quantization range lupper end(s).
|
||||
per_channel: a boolean specifying whether to use per-channel quantizatioh.
|
||||
max_var: a variable containing quantization range upper end(s).
|
||||
per_channel: a boolean specifying whether to use per-channel quantization.
|
||||
num_bits: Number of bits to use for quantization, must be between 2 and 8.
|
||||
narrow_range: Whether to use the narrow quantization range
|
||||
[1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
|
||||
|
@ -341,7 +341,7 @@ def _InsertQuantOp(context,
|
||||
"""Inserts a quant op between a producer op and (multiple) consumer ops.
|
||||
|
||||
Args:
|
||||
context: Context w,here producer and consumer operations are nested.
|
||||
context: Context where producer and consumer operations are nested.
|
||||
name: Name for the new quantization op within the context.
|
||||
producer: Producer operation of the pairs where quantization will be
|
||||
inserted.
|
||||
|
@ -155,7 +155,7 @@ def experimental_create_training_graph(input_graph=None,
|
||||
often fail.
|
||||
|
||||
Args:
|
||||
input_graph: The tf.Graph to be transformed,if None then defaults to the
|
||||
input_graph: The tf.Graph to be transformed, if None then defaults to the
|
||||
default graph.
|
||||
weight_bits: Number of bits to use for quantizing weights.
|
||||
activation_bits: Number of bits to use for quantizing activations.
|
||||
|
@ -419,7 +419,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
||||
normalizer_params=self._BatchNormParams(fused_batch_norm),
|
||||
scope=scope)
|
||||
|
||||
# Manually add a bypass (optionaly) and an activation.
|
||||
# Manually add a bypass (optional) and an activation.
|
||||
if with_bypass:
|
||||
node = math_ops.add(inputs, node, name='test/Add')
|
||||
|
||||
@ -470,7 +470,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
||||
normalizer_params=self._BatchNormParams(fused_batch_norm),
|
||||
scope=scope)
|
||||
|
||||
# Manually add a bypass (optionaly) and an activation.
|
||||
# Manually add a bypass (optional) and an activation.
|
||||
if with_bypass:
|
||||
node = math_ops.add(inputs, node, name='test/Add')
|
||||
|
||||
@ -526,7 +526,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
||||
normalizer_params=self._BatchNormParams(fused_batch_norm),
|
||||
scope=scope)
|
||||
|
||||
# Manually add a bypass (optionaly) and an activation.
|
||||
# Manually add a bypass (optional) and an activation.
|
||||
if with_bypass:
|
||||
node = math_ops.add(inputs, node, name='test/Add')
|
||||
|
||||
@ -565,7 +565,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
||||
stddev: Standard deviation of normal variable.
|
||||
|
||||
Returns:
|
||||
An initialized that initialzes with a truncated normal variable.
|
||||
An initializer that initializes with a truncated normal variable.
|
||||
"""
|
||||
return init_ops.truncated_normal_initializer(stddev=stddev)
|
||||
|
||||
|
@ -197,7 +197,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
||||
stddev: Standard deviation of normal variable.
|
||||
|
||||
Returns:
|
||||
An initialized that initialzes with a truncated normal variable.
|
||||
An initializer that initializes with a truncated normal variable.
|
||||
"""
|
||||
return init_ops.truncated_normal_initializer(stddev=stddev)
|
||||
|
||||
|
@ -69,7 +69,7 @@ Element-wise dot product of a and b is represented by ab
|
||||
Element-wise dot product is represented by \circ
|
||||
Matrix multiplication is represented by *
|
||||
|
||||
Baises are initialized with :
|
||||
Biases are initialized with :
|
||||
`b_ru` - constant_initializer(1.0)
|
||||
`b_c` - constant_initializer(0.0)
|
||||
|
||||
|
@ -54,7 +54,7 @@ def blocks_match(sess, use_peephole):
|
||||
initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212)
|
||||
|
||||
with variable_scope.variable_scope("test", initializer=initializer):
|
||||
# magic naming so that the cells pick up these variables and resuse them
|
||||
# magic naming so that the cells pick up these variables and reuse them
|
||||
if use_peephole:
|
||||
wci = variable_scope.get_variable(
|
||||
"rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32)
|
||||
|
@ -480,8 +480,7 @@ class LSTMBlockWrapper(base_layer.Layer):
|
||||
"""Run this LSTM on inputs, starting from the given state.
|
||||
|
||||
Args:
|
||||
inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
|
||||
or a list of `time_len` tensors of shape `[batch_size, input_size]`.
|
||||
inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
|
||||
initial_state: a tuple `(initial_cell_state, initial_output)` with tensors
|
||||
of shape `[batch_size, self._num_units]`. If this is not provided, the
|
||||
cell is expected to create a zero initial state of type `dtype`.
|
||||
|
@ -534,7 +534,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
|
||||
initializer: (optional) The initializer to use for the weight and
|
||||
projection matrices, default None.
|
||||
num_unit_shards: (optional) int, default 1, How to split the weight
|
||||
matrix. If > 1,the weight matrix is stored across num_unit_shards.
|
||||
matrix. If > 1, the weight matrix is stored across num_unit_shards.
|
||||
forget_bias: (optional) float, default 1.0, The initial bias of the
|
||||
forget gates, used to reduce the scale of forgetting at the beginning
|
||||
of the training.
|
||||
@ -993,7 +993,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
|
||||
initializer: (optional) The initializer to use for the weight and
|
||||
projection matrices, default None.
|
||||
num_unit_shards: (optional) int, default 1, How to split the weight
|
||||
matrix. If > 1,the weight matrix is stored across num_unit_shards.
|
||||
matrix. If > 1, the weight matrix is stored across num_unit_shards.
|
||||
forget_bias: (optional) float, default 1.0, The initial bias of the
|
||||
forget gates, used to reduce the scale of forgetting at the beginning
|
||||
of the training.
|
||||
@ -2133,7 +2133,7 @@ class Conv1DLSTMCell(ConvLSTMCell):
|
||||
|
||||
def __init__(self, name="conv_1d_lstm_cell", **kwargs):
|
||||
"""Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
|
||||
super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs)
|
||||
super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs)
|
||||
|
||||
|
||||
class Conv2DLSTMCell(ConvLSTMCell):
|
||||
@ -2144,7 +2144,7 @@ class Conv2DLSTMCell(ConvLSTMCell):
|
||||
|
||||
def __init__(self, name="conv_2d_lstm_cell", **kwargs):
|
||||
"""Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
|
||||
super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs)
|
||||
super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs)
|
||||
|
||||
|
||||
class Conv3DLSTMCell(ConvLSTMCell):
|
||||
@ -2155,7 +2155,7 @@ class Conv3DLSTMCell(ConvLSTMCell):
|
||||
|
||||
def __init__(self, name="conv_3d_lstm_cell", **kwargs):
|
||||
"""Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
|
||||
super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
|
||||
super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs)
|
||||
|
||||
|
||||
def _conv(args, filter_size, num_features, bias, bias_start=0.0):
|
||||
|
@ -222,6 +222,9 @@ class AttentionWrapperTest(test.TestCase):
|
||||
self.assertEqual(
|
||||
(None, batch_size, None),
|
||||
tuple(state_alignment_history.get_shape().as_list()))
|
||||
nest.assert_same_structure(
|
||||
cell.state_size,
|
||||
cell.zero_state(batch_size, dtypes.float32))
|
||||
# Remove the history from final_state for purposes of the
|
||||
# remainder of the tests.
|
||||
final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
|
||||
from tensorflow.contrib.seq2seq.python.ops import decoder
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -70,6 +71,98 @@ class TestGatherTree(test.TestCase):
|
||||
|
||||
self.assertAllEqual(expected_result, res_)
|
||||
|
||||
def _test_gather_tree_from_array(self,
                                 depth_ndims=0,
                                 merged_batch_beam=False):
  """Checks `gather_tree_from_array` against a hand-written expected result.

  Args:
    depth_ndims: Number of extra trailing dimensions stacked onto the
      fixtures before they are passed to `gather_tree_from_array`.
    merged_batch_beam: If True, the input and expected tensors are reshaped
      to `[max_time, batch_size * beam_width]` before the depth dimensions
      are added.
  """
  # The fixtures are written batch-first and transposed so time comes first.
  time_first = [1, 0, 2]
  array = np.array(
      [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 0, 0]],
       [[2, 3, 4], [5, 6, 7], [8, 9, 10], [11, 12, 0]]]).transpose(time_first)
  parent_ids = np.array(
      [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]],
       [[0, 0, 0], [1, 1, 0], [2, 0, 1], [0, 1, 0]]]).transpose(time_first)
  expected_array = np.array(
      [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [0, 0, 0]],
       [[2, 3, 2], [7, 5, 7], [8, 9, 8], [11, 12, 0]]]).transpose(time_first)
  sequence_length = [[3, 3, 3], [4, 4, 3]]

  array = ops.convert_to_tensor(array, dtype=dtypes.float32)
  parent_ids = ops.convert_to_tensor(parent_ids, dtype=dtypes.int32)
  expected_array = ops.convert_to_tensor(expected_array, dtype=dtypes.float32)

  # Dynamic sizes of the three leading dimensions, read one axis at a time.
  max_time, batch_size, beam_width = (
      array_ops.shape(array)[axis] for axis in range(3))

  def _tile_in_depth(tensor):
    # Each pass appends a dimension of size 2 holding tensor and tensor + 1.
    remaining = depth_ndims
    while remaining > 0:
      tensor = array_ops.stack([tensor, tensor + 1], -1)
      remaining -= 1
    return tensor

  if merged_batch_beam:
    array = array_ops.reshape(
        array, [max_time, batch_size * beam_width])
    expected_array = array_ops.reshape(
        expected_array, [max_time, batch_size * beam_width])

  if depth_ndims > 0:
    array = _tile_in_depth(array)
    expected_array = _tile_in_depth(expected_array)

  sorted_array = beam_search_decoder.gather_tree_from_array(
      array, parent_ids, sequence_length)

  with self.test_session() as sess:
    sorted_array = sess.run(sorted_array)
    expected_array = sess.run(expected_array)
    self.assertAllEqual(expected_array, sorted_array)
|
||||
|
||||
def test_gather_tree_from_array_scalar(self):
  """Runs the gather check with no extra depth dimensions."""
  self._test_gather_tree_from_array(depth_ndims=0, merged_batch_beam=False)
|
||||
|
||||
def test_gather_tree_from_array_1d(self):
  """Runs the gather check with one extra depth dimension."""
  self._test_gather_tree_from_array(depth_ndims=1, merged_batch_beam=False)
|
||||
|
||||
def test_gather_tree_from_array_1d_with_merged_batch_beam(self):
  """Runs the gather check with one depth dimension and merged batch/beam."""
  options = dict(depth_ndims=1, merged_batch_beam=True)
  self._test_gather_tree_from_array(**options)
|
||||
|
||||
def test_gather_tree_from_array_2d(self):
  """Runs the gather check with two extra depth dimensions."""
  self._test_gather_tree_from_array(depth_ndims=2, merged_batch_beam=False)
|
||||
|
||||
|
||||
class TestArrayShapeChecks(test.TestCase):
  """Exercises the dynamic batch/beam shape assertion of the beam decoder."""

  def _test_array_shape_dynamic_checks(self, static_shape, dynamic_shape,
                                       batch_size, beam_width, is_valid=True):
    """Builds the assertion for one shape and runs it.

    Args:
      static_shape: Shape of the random float32 value fed by default.
      dynamic_shape: Shape declared on the placeholder.
      batch_size: Python int, wrapped in a constant before the check.
      beam_width: Beam width passed to the check.
      is_valid: If False, running the assertion must raise
        `InvalidArgumentError`; otherwise it must run cleanly.
    """
    default_value = np.random.randn(*static_shape).astype(np.float32)
    t = array_ops.placeholder_with_default(default_value, shape=dynamic_shape)

    batch_size = array_ops.constant(batch_size)
    check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width)  # pylint: disable=protected-access

    with self.test_session() as sess:
      if not is_valid:
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(check_op)
        return
      sess.run(check_op)

  def test_array_shape_dynamic_checks(self):
    """Runs the assertion over accepted and rejected shapes (batch 4, beam 5)."""
    # (static_shape, dynamic_shape, is_valid)
    cases = [
        ((8, 4, 5, 10), (None, None, 5, 10), True),
        ((8, 20, 10), (None, None, 10), True),
        ((8, 21, 10), (None, None, 10), False),
        ((8, 4, 6, 10), (None, None, None, 10), False),
        ((8, 4), (None, None), False),
    ]
    for static_shape, dynamic_shape, is_valid in cases:
      self._test_array_shape_dynamic_checks(
          static_shape, dynamic_shape, 4, 5, is_valid=is_valid)
|
||||
|
||||
|
||||
class TestEosMasking(test.TestCase):
|
||||
"""Tests EOS masking used in beam search."""
|
||||
@ -319,7 +412,8 @@ class TestLargeBeamStep(test.TestCase):
|
||||
|
||||
class BeamSearchDecoderTest(test.TestCase):
|
||||
|
||||
def _testDynamicDecodeRNN(self, time_major, has_attention):
|
||||
def _testDynamicDecodeRNN(self, time_major, has_attention,
|
||||
with_alignment_history=False):
|
||||
encoder_sequence_length = np.array([3, 2, 3, 1, 1])
|
||||
decoder_sequence_length = np.array([2, 0, 1, 2, 3])
|
||||
batch_size = 5
|
||||
@ -359,7 +453,7 @@ class BeamSearchDecoderTest(test.TestCase):
|
||||
cell=cell,
|
||||
attention_mechanism=attention_mechanism,
|
||||
attention_layer_size=attention_depth,
|
||||
alignment_history=False)
|
||||
alignment_history=with_alignment_history)
|
||||
cell_state = cell.zero_state(
|
||||
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
|
||||
if has_attention:
|
||||
@ -420,6 +514,12 @@ class BeamSearchDecoderTest(test.TestCase):
|
||||
def testDynamicDecodeRNNBatchMajorYesAttention(self):
|
||||
self._testDynamicDecodeRNN(time_major=False, has_attention=True)
|
||||
|
||||
def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
  """Batch-major decode with attention and alignment history switched on."""
  options = dict(
      time_major=False,
      has_attention=True,
      with_alignment_history=True)
  self._testDynamicDecodeRNN(**options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -1278,7 +1278,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
|
||||
attention_state=self._item_or_tuple(
|
||||
a.state_size for a in self._attention_mechanisms),
|
||||
alignment_history=self._item_or_tuple(
|
||||
() for _ in self._attention_mechanisms)) # sometimes a TensorArray
|
||||
a.alignments_size if self._alignment_history else ()
|
||||
for a in self._attention_mechanisms)) # sometimes a TensorArray
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
"""Return an initial (zero) state tuple for this `AttentionWrapper`.
|
||||
@ -1318,22 +1319,26 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
|
||||
cell_state = nest.map_structure(
|
||||
lambda s: array_ops.identity(s, name="checked_cell_state"),
|
||||
cell_state)
|
||||
initial_alignments = [
|
||||
attention_mechanism.initial_alignments(batch_size, dtype)
|
||||
for attention_mechanism in self._attention_mechanisms]
|
||||
return AttentionWrapperState(
|
||||
cell_state=cell_state,
|
||||
time=array_ops.zeros([], dtype=dtypes.int32),
|
||||
attention=_zero_state_tensors(self._attention_layer_size, batch_size,
|
||||
dtype),
|
||||
alignments=self._item_or_tuple(
|
||||
attention_mechanism.initial_alignments(batch_size, dtype)
|
||||
for attention_mechanism in self._attention_mechanisms),
|
||||
alignments=self._item_or_tuple(initial_alignments),
|
||||
attention_state=self._item_or_tuple(
|
||||
attention_mechanism.initial_state(batch_size, dtype)
|
||||
for attention_mechanism in self._attention_mechanisms),
|
||||
alignment_history=self._item_or_tuple(
|
||||
tensor_array_ops.TensorArray(dtype=dtype, size=0,
|
||||
dynamic_size=True)
|
||||
tensor_array_ops.TensorArray(
|
||||
dtype,
|
||||
size=0,
|
||||
dynamic_size=True,
|
||||
element_shape=alignment.shape)
|
||||
if self._alignment_history else ()
|
||||
for _ in self._attention_mechanisms))
|
||||
for alignment in initial_alignments))
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Perform a step of attention-wrapped RNN.
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
__all__ = [
|
||||
@ -121,14 +122,114 @@ def tile_batch(t, multiplier, name=None):
|
||||
return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
|
||||
|
||||
|
||||
def gather_tree_from_array(t, parent_ids, sequence_length):
  """Reorders a stacked `TensorArray` along the beams chosen by `parent_ids`.

  Args:
    t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
      shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
      where `s` is the depth shape.
    parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
    sequence_length: The sequence length of shape `[batch_size, beam_width]`.

  Returns:
    A `Tensor` which is a stacked `TensorArray` of the same size and type as
    `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
  """
  # Use the static dimension when it is set, otherwise the dynamic one.
  parent_shape = parent_ids.shape
  max_time = parent_shape[0].value or array_ops.shape(parent_ids)[0]
  batch_size = parent_shape[1].value or array_ops.shape(parent_ids)[1]
  beam_width = parent_shape[2].value or array_ops.shape(parent_ids)[2]

  # Identity beam ids, tiled to [max_time, batch_size, beam_width]; these are
  # what gather_tree reorders.
  beam_range = math_ops.range(beam_width)
  beam_ids = array_ops.expand_dims(beam_range, 0)
  beam_ids = array_ops.expand_dims(beam_ids, 0)
  beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1])

  # 1 for steps inside the sequence, 0 after it; moved to time-first layout.
  mask = array_ops.sequence_mask(
      sequence_length, maxlen=max_time, dtype=dtypes.int32)
  mask = array_ops.transpose(mask, perm=[2, 0, 1])

  # Steps past the sequence end get the marker value beam_width + 1.
  in_range_ids = beam_ids * mask
  out_of_range_ids = (1 - mask) * (beam_width + 1)
  masked_beam_ids = in_range_ids + out_of_range_ids

  longest_per_batch = math_ops.reduce_max(sequence_length, axis=1)
  max_sequence_lengths = math_ops.to_int32(longest_per_batch)
  sorted_beam_ids = beam_search_ops.gather_tree(
      step_ids=masked_beam_ids,
      parent_ids=parent_ids,
      max_sequence_lengths=max_sequence_lengths,
      end_token=beam_width + 1)

  # Out-of-range steps keep their own beam instead of the gathered one.
  is_in_range = math_ops.cast(mask, dtypes.bool)
  sorted_beam_ids = array_ops.where(
      is_in_range, x=sorted_beam_ids, y=beam_ids)

  # Build (time, batch, beam) index triples for gather_nd.
  time_column = array_ops.reshape(math_ops.range(max_time), [-1, 1, 1])
  time_ind = array_ops.tile(time_column, [1, batch_size, beam_width])
  batch_column = array_ops.reshape(math_ops.range(batch_size), [-1, 1, 1])
  batch_ind = array_ops.tile(batch_column, [1, max_time, beam_width])
  batch_ind = array_ops.transpose(batch_ind, perm=[1, 0, 2])
  indices = array_ops.stack([time_ind, batch_ind, sorted_beam_ids], -1)

  # Collapse all depth dimensions into one, gather, then restore the
  # original dynamic shape of `t`.
  final_shape = array_ops.shape(t)
  flattened = array_ops.reshape(
      t, [max_time, batch_size, beam_width, -1])
  ordered = array_ops.gather_nd(flattened, indices)
  return array_ops.reshape(ordered, final_shape)
|
||||
|
||||
|
||||
def _check_maybe(t):
|
||||
if isinstance(t, tensor_array_ops.TensorArray):
|
||||
raise TypeError(
|
||||
"TensorArray state is not supported by BeamSearchDecoder: %s" % t.name)
|
||||
if t.shape.ndims is None:
|
||||
raise ValueError(
|
||||
"Expected tensor (%s) to have known rank, but ndims == None." % t)
|
||||
|
||||
def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
  """Statically checks that `shape` fits `[batch_size, beam_width, -1]`.

  No exception is raised: on a mismatch a warning is logged and `False` is
  returned.

  Two layouts are accepted, matching the dynamic check in
  `_check_batch_beam`:
    * merged: the first dimension equals `batch_size * beam_width`;
    * split: the first two dimensions equal `batch_size` and `beam_width`.

  Args:
    shape: A `TensorShape` of known rank >= 1.
    batch_size: Python int, or `None` if the batch size is not known
      statically.
    beam_width: Python int beam width.

  Returns:
    `True` if `batch_size` or the first dimension is unknown (nothing can be
    verified statically) or if one of the layouts above matches; `False`
    otherwise.
  """
  reshaped_shape = tensor_shape.TensorShape([batch_size, beam_width, None])
  # Without a static batch size or first dimension nothing can be checked.
  if batch_size is None or shape[0].value is None:
    return True
  # Merged layout: [batch_size * beam_width, ...].
  if shape[0] == batch_size * beam_width:
    return True
  # Split layout: [batch_size, beam_width, ...].
  has_second_dim = shape.ndims >= 2 and shape[1].value is not None
  if has_second_dim and shape[0] == batch_size and shape[1] == beam_width:
    return True
  tf_logging.warn("TensorArray reordering expects elements to be "
                  "reshapable to %s which is incompatible with the "
                  "current shape %s. Consider setting "
                  "reorder_tensor_arrays to False to disable TensorArray "
                  "reordering during the beam search."
                  % (reshaped_shape, shape))
  return False
|
||||
|
||||
def _check_batch_beam(t, batch_size, beam_width):
  """Builds an Assert op validating the batch/beam layout of a stacked array.

  The assertion passes when dimension 1 of `t` equals
  `batch_size * beam_width`, or, for tensors whose rank is not 2, when
  dimensions 1 and 2 equal `batch_size` and `beam_width`. At this point the
  TensorArray elements have a known rank of at least 1.

  Args:
    t: The stacked `TensorArray` tensor to check.
    batch_size: The batch size.
    beam_width: The beam width.

  Returns:
    The `Assert` operation.
  """
  error_message = ("TensorArray reordering expects elements to be "
                   "reshapable to [batch_size, beam_size, -1] which is "
                   "incompatible with the dynamic shape of %s elements. "
                   "Consider setting reorder_tensor_arrays to False to disable "
                   "TensorArray reordering during the beam search."
                   % (t.name))
  rank = t.shape.ndims
  shape = array_ops.shape(t)
  # Merged layout: a single dimension holds batch_size * beam_width.
  condition = math_ops.equal(shape[1], batch_size * beam_width)
  if rank != 2:
    # A third dimension exists, so the split layout is accepted as well.
    is_split = math_ops.logical_and(
        math_ops.equal(shape[1], batch_size),
        math_ops.equal(shape[2], beam_width))
    condition = math_ops.logical_or(condition, is_split)
  return control_flow_ops.Assert(condition, [error_message])
|
||||
|
||||
|
||||
|
||||
class BeamSearchDecoder(decoder.Decoder):
|
||||
"""BeamSearch sampling decoder.
|
||||
@ -173,7 +274,8 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
initial_state,
|
||||
beam_width,
|
||||
output_layer=None,
|
||||
length_penalty_weight=0.0):
|
||||
length_penalty_weight=0.0,
|
||||
reorder_tensor_arrays=True):
|
||||
"""Initialize the BeamSearchDecoder.
|
||||
|
||||
Args:
|
||||
@ -188,6 +290,12 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
|
||||
to storing the result or sampling.
|
||||
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
|
||||
reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
|
||||
state will be reordered according to the beam search path. If the
|
||||
`TensorArray` can be reordered, the stacked form will be returned.
|
||||
Otherwise, the `TensorArray` will be returned as is. Set this flag to
|
||||
`False` if the cell state contains `TensorArray`s that are not amenable
|
||||
to reordering.
|
||||
|
||||
Raises:
|
||||
TypeError: if `cell` is not an instance of `RNNCell`,
|
||||
@ -202,6 +310,7 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
"output_layer must be a Layer, received: %s" % type(output_layer))
|
||||
self._cell = cell
|
||||
self._output_layer = output_layer
|
||||
self._reorder_tensor_arrays = reorder_tensor_arrays
|
||||
|
||||
if callable(embedding):
|
||||
self._embedding_fn = embedding
|
||||
@ -299,12 +408,13 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
"""
|
||||
finished, start_inputs = self._finished, self._start_inputs
|
||||
|
||||
dtype = nest.flatten(self._initial_cell_state)[0].dtype
|
||||
log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
|
||||
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
|
||||
depth=self._beam_width,
|
||||
on_value=0.0,
|
||||
off_value=-np.Inf,
|
||||
dtype=nest.flatten(self._initial_cell_state)[0].dtype)
|
||||
on_value=ops.convert_to_tensor(0.0, dtype=dtype),
|
||||
off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
|
||||
dtype=dtype)
|
||||
|
||||
initial_state = BeamSearchDecoderState(
|
||||
cell_state=self._initial_cell_state,
|
||||
@ -341,6 +451,11 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
outputs.parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=self._end_token)
|
||||
if self._reorder_tensor_arrays:
|
||||
final_state = final_state._replace(cell_state=nest.map_structure(
|
||||
lambda t: self._maybe_sort_array_beams(
|
||||
t, outputs.parent_ids, final_state.lengths),
|
||||
final_state.cell_state))
|
||||
outputs = FinalBeamSearchDecoderOutput(
|
||||
beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
|
||||
return outputs, final_state
|
||||
@ -431,9 +546,10 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
returned unchanged.
|
||||
|
||||
Raises:
|
||||
TypeError: If `t` is an instance of `TensorArray`.
|
||||
ValueError: If the rank of `t` is not statically known.
|
||||
"""
|
||||
if isinstance(t, tensor_array_ops.TensorArray):
|
||||
return t
|
||||
_check_maybe(t)
|
||||
if t.shape.ndims >= 1:
|
||||
return self._split_batch_beams(t, s)
|
||||
@ -454,15 +570,55 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
A reshaped version of t with shape `[batch_size, beam_width] + s`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `t` is an instance of `TensorArray`.
|
||||
ValueError: If the rank of `t` is not statically known.
|
||||
"""
|
||||
if isinstance(t, tensor_array_ops.TensorArray):
|
||||
return t
|
||||
_check_maybe(t)
|
||||
if t.shape.ndims >= 2:
|
||||
return self._merge_batch_beams(t, s)
|
||||
else:
|
||||
return t
|
||||
|
||||
def _maybe_sort_array_beams(self, t, parent_ids, sequence_length):
|
||||
"""Maybe sorts beams within a `TensorArray`.
|
||||
|
||||
Args:
|
||||
t: A `TensorArray` of size `max_time` that contains `Tensor`s of shape
|
||||
`[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` where
|
||||
`s` is the depth shape.
|
||||
parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
|
||||
sequence_length: The sequence length of shape `[batch_size, beam_width]`.
|
||||
|
||||
Returns:
|
||||
A `TensorArray` where beams are sorted in each `Tensor` or `t` itself if
|
||||
it is not a `TensorArray` or does not meet shape requirements.
|
||||
"""
|
||||
if not isinstance(t, tensor_array_ops.TensorArray):
|
||||
return t
|
||||
# pylint: disable=protected-access
|
||||
if (not t._infer_shape or not t._element_shape
|
||||
or t._element_shape[0].ndims is None
|
||||
or t._element_shape[0].ndims < 1):
|
||||
shape = (
|
||||
t._element_shape[0] if t._infer_shape and t._element_shape
|
||||
else tensor_shape.TensorShape(None))
|
||||
tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
|
||||
"sorting based on the beam search result. For a "
|
||||
"TensorArray to be sorted, its elements shape must be "
|
||||
"defined and have at least a rank of 1, but saw shape: %s"
|
||||
% (t.handle.name, shape))
|
||||
return t
|
||||
shape = t._element_shape[0]
|
||||
# pylint: enable=protected-access
|
||||
if not _check_static_batch_beam_maybe(
|
||||
shape, tensor_util.constant_value(self._batch_size), self._beam_width):
|
||||
return t
|
||||
t = t.stack()
|
||||
with ops.control_dependencies(
|
||||
[_check_batch_beam(t, self._batch_size, self._beam_width)]):
|
||||
return gather_tree_from_array(t, parent_ids, sequence_length)
|
||||
|
||||
def step(self, time, inputs, state, name=None):
|
||||
"""Perform a decoding step.
|
||||
|
||||
@ -757,6 +913,8 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
|
||||
output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
|
||||
or the original tensor if its dimensions are too small.
|
||||
"""
|
||||
if isinstance(gather_from, tensor_array_ops.TensorArray):
|
||||
return gather_from
|
||||
_check_maybe(gather_from)
|
||||
if gather_from.shape.ndims >= len(gather_shape):
|
||||
return _tensor_gather_helper(
|
||||
|
@ -94,7 +94,7 @@ of thin wrapper functions in
|
||||
[variables.py](https://www.tensorflow.org/code/tensorflow/contrib/framework/python/ops/variables.py)
|
||||
which allow callers to easily define variables.
|
||||
|
||||
For example, to create a `weight` variable, initialize it using a truncated
|
||||
For example, to create a `weights` variable, initialize it using a truncated
|
||||
normal distribution, regularize it with an `l2_loss` and place it on the `CPU`,
|
||||
one need only declare the following:
|
||||
|
||||
|
@ -33,7 +33,7 @@ def cgls(operator, rhs, tol=1e-6, max_iter=20, name="cgls"):
|
||||
r"""Conjugate gradient least squares solver.
|
||||
|
||||
Solves a linear least squares problem \\(||A x - rhs||_2\\) for a single
|
||||
righ-hand side, using an iterative, matrix-free algorithm where the action of
|
||||
right-hand side, using an iterative, matrix-free algorithm where the action of
|
||||
the matrix A is represented by `operator`. The CGLS algorithm implicitly
|
||||
applies the symmetric conjugate gradient algorithm to the normal equations
|
||||
\\(A^* A x = A^* rhs\\). The iteration terminates when either
|
||||
|
@ -41,7 +41,7 @@ def conjugate_gradient(operator,
|
||||
r"""Conjugate gradient solver.
|
||||
|
||||
Solves a linear system of equations `A*x = rhs` for selfadjoint, positive
|
||||
definite matrix `A` and righ-hand side vector `rhs`, using an iterative,
|
||||
definite matrix `A` and right-hand side vector `rhs`, using an iterative,
|
||||
matrix-free algorithm where the action of the matrix A is represented by
|
||||
`operator`. The iteration terminates when either the number of iterations
|
||||
exceeds `max_iter` or when the residual norm has been reduced to `tol`
|
||||
|
@ -83,6 +83,7 @@ cc_library(
|
||||
"kernels/trt_engine_op.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":trt_logging",
|
||||
":trt_resources",
|
||||
@ -154,6 +155,7 @@ py_library(
|
||||
deps = [
|
||||
":trt_convert_py",
|
||||
":trt_ops_py",
|
||||
"//tensorflow/python:errors",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -2,7 +2,8 @@ Using TensorRT in TensorFlow
|
||||
============================
|
||||
|
||||
This module provides necessary bindings and introduces TRT_engine_op
|
||||
operator that wraps a subgraph in TensorRT.
|
||||
operator that wraps a subgraph in TensorRT. This is still a work in progress
|
||||
but should be usable with most common graphs.
|
||||
|
||||
Compilation
|
||||
-----------
|
||||
@ -15,26 +16,10 @@ configure script should find the necessary components from the system
|
||||
automatically. If installed from tar packages, user has to set path to
|
||||
location where the library is installed during configuration.
|
||||
|
||||
|
||||
```
|
||||
```shell
|
||||
bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package
|
||||
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
|
||||
```
|
||||
|
||||
After the installation of tensorflow package, TensorRT transformation
|
||||
will be available. An example use is shown below.
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.tensorrt as trt
|
||||
#... create and train or load model
|
||||
gdef = sess.graph.as_graph_def()
|
||||
trt_gdef = trt.create_inference_graph(
|
||||
gdef, #original graph_def
|
||||
["output"], #name of output node(s)
|
||||
max_batch_size, #maximum batch size to run the inference
|
||||
max_workspace_size_bytes) # max memory for TensorRT to use
|
||||
tf.reset_default_graph()
|
||||
tf.import_graph_def(graph_def=trt_gdef)
|
||||
#...... run inference
|
||||
```
|
||||
will be available. An example use can be found in test/test_tftrt.py.
|
||||
|
@ -18,6 +18,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.tensorrt.python import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import,g-import-not-at-top
|
||||
try:
|
||||
from tensorflow.contrib.tensorrt.python import *
|
||||
except errors.NotFoundError as e:
|
||||
no_trt_message = (
|
||||
'**** Failed to initialize TensorRT. This is either because the TensorRT'
|
||||
' installation path is not in LD_LIBRARY_PATH, or because you do not have'
|
||||
' it installed. If not installed, please go to'
|
||||
' https://developer.nvidia.com/tensorrt to download and install'
|
||||
' TensorRT ****')
|
||||
print(no_trt_message)
|
||||
raise e
|
||||
# pylint: enable=unused-import,wildcard-import,g-import-not-at-top
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
|
||||
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
@ -48,13 +49,29 @@ namespace tensorrt {
|
||||
namespace convert {
|
||||
namespace {
|
||||
|
||||
static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
|
||||
bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
|
||||
// LINT.IfChange
|
||||
// TODO(jie): Segmentation shouldn't be associated with op name.
|
||||
// Split it into a registration for each kernel.
|
||||
static const std::set<string> candidate_ops = {
|
||||
"Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
|
||||
"Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean"
|
||||
"Identity",
|
||||
"Const",
|
||||
"Conv2D",
|
||||
"MaxPool",
|
||||
"BiasAdd",
|
||||
"Relu",
|
||||
"Add",
|
||||
"Mul",
|
||||
"Sub",
|
||||
"Rsqrt",
|
||||
"Pad",
|
||||
"Mean",
|
||||
"AvgPool",
|
||||
"ConcatV2",
|
||||
"DepthwiseConv2dNative",
|
||||
"FusedBatchNorm",
|
||||
"FusedBatchNormV2",
|
||||
// TODO(ben,jie): ...
|
||||
};
|
||||
// LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
|
||||
return candidate_ops.count(node_def.op());
|
||||
@ -69,6 +86,8 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
|
||||
if (!subgraph_node_ids.count(edge->src()->id()) &&
|
||||
!edge->src()->IsSource()) {
|
||||
incoming_edges->insert(edge);
|
||||
} else {
|
||||
VLOG(2) << edge->src()->name() << " N, ";
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -82,7 +101,10 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
|
||||
for (const tensorflow::Edge* edge : node->out_edges()) {
|
||||
if (!subgraph_node_ids.count(edge->dst()->id()) &&
|
||||
!edge->dst()->IsSink()) {
|
||||
VLOG(2) << edge->dst()->name() << " Y, ";
|
||||
outgoing_edges->insert(edge);
|
||||
} else {
|
||||
VLOG(2) << edge->dst()->name() << " N, ";
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -109,74 +131,150 @@ std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertSubGraphToTensorRT(
|
||||
const std::vector<string>& output_names,
|
||||
const std::set<int>& subgraph_node_ids,
|
||||
size_t max_batch_size, // Max batch size that engine will be created for
|
||||
// Max amount of memory that engine will be allowed to consume, in bytes
|
||||
size_t max_workspace_size_bytes,
|
||||
const tensorflow::grappler::GraphProperties& graph_properties,
|
||||
tensorflow::Graph* graph) {
|
||||
tensorflow::EdgeSet subgraph_incoming_edges;
|
||||
GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
|
||||
|
||||
// TODO(sami): convert references to pointers
|
||||
struct ConvertGraphParams {
|
||||
ConvertGraphParams(
|
||||
tensorflow::Graph& inp_graph,
|
||||
const std::vector<string>& output_node_names,
|
||||
const std::set<int>& subgraph_node_id_numbers,
|
||||
size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
|
||||
const tensorflow::grappler::GraphProperties& current_graph_properties,
|
||||
std::unordered_map<string, std::pair<int, string>>* output_edges,
|
||||
int engine_precision_mode)
|
||||
: graph(inp_graph),
|
||||
output_names(output_node_names),
|
||||
subgraph_node_ids(subgraph_node_id_numbers),
|
||||
max_batch_size(max_supported_batch_size),
|
||||
max_workspace_size_bytes(max_consumed_workspace_size_bytes),
|
||||
graph_properties(current_graph_properties),
|
||||
output_edge_map(output_edges),
|
||||
precision_mode(engine_precision_mode) {}
|
||||
tensorflow::Graph& graph;
|
||||
const std::vector<string>& output_names;
|
||||
const std::set<int>& subgraph_node_ids;
|
||||
size_t max_batch_size;
|
||||
size_t max_workspace_size_bytes;
|
||||
const tensorflow::grappler::GraphProperties& graph_properties;
|
||||
std::unordered_map<string, std::pair<int, string>>* output_edge_map;
|
||||
int precision_mode;
|
||||
std::vector<std::pair<int, int>> subgraph_inputs;
|
||||
std::vector<std::pair<int, int>> subgraph_outputs;
|
||||
tensorflow::EdgeSet subgraph_incoming_edges;
|
||||
tensorflow::EdgeSet subgraph_outgoing_edges;
|
||||
};
|
||||
|
||||
// Collect inputs by looking for incoming edges
|
||||
for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
|
||||
subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
|
||||
static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) {
|
||||
GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids,
|
||||
&p->subgraph_incoming_edges);
|
||||
for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) {
|
||||
p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
|
||||
}
|
||||
auto output_name_to_index_map = BuildTensorNameMap(p->output_names);
|
||||
std::set<std::pair<int, int>> subgraph_outputs_set;
|
||||
// Collect outputs referenced from output_names
|
||||
auto output_name_to_index_map = BuildTensorNameMap(output_names);
|
||||
for (int node_id : subgraph_node_ids) {
|
||||
tensorflow::Node* node = graph->FindNodeId(node_id);
|
||||
for (int node_id : p->subgraph_node_ids) {
|
||||
tensorflow::Node* node = p->graph.FindNodeId(node_id);
|
||||
if (output_name_to_index_map.count(node->name())) {
|
||||
for (int index : output_name_to_index_map.at(node->name())) {
|
||||
subgraph_outputs_set.insert({node_id, index});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Collect outputs referenced from outgoing edges
|
||||
tensorflow::EdgeSet subgraph_outgoing_edges;
|
||||
GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
|
||||
for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
|
||||
GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids,
|
||||
&p->subgraph_outgoing_edges);
|
||||
for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) {
|
||||
subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
|
||||
}
|
||||
// Impose an ordering on the outputs
|
||||
std::vector<std::pair<int, int>> subgraph_outputs(
|
||||
subgraph_outputs_set.begin(), subgraph_outputs_set.end());
|
||||
// Build TensorRT node and add it to the graph
|
||||
p->subgraph_outputs.reserve(subgraph_outputs_set.size());
|
||||
p->subgraph_outputs.insert(p->subgraph_outputs.begin(),
|
||||
subgraph_outputs_set.begin(),
|
||||
subgraph_outputs_set.end());
|
||||
return tensorflow::Status::OK();
|
||||
};
|
||||
|
||||
tensorflow::Status GetCalibNode(ConvertGraphParams* params) {
|
||||
TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params));
|
||||
tensorflow::NodeDef trt_node_def;
|
||||
TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
|
||||
*graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
|
||||
max_batch_size, max_workspace_size_bytes, graph_properties,
|
||||
&trt_node_def));
|
||||
SubGraphParams s(params->graph, params->subgraph_node_ids,
|
||||
params->subgraph_inputs, params->subgraph_outputs,
|
||||
params->max_batch_size, params->max_workspace_size_bytes,
|
||||
params->graph_properties, params->output_edge_map,
|
||||
&trt_node_def, params->precision_mode);
|
||||
TF_RETURN_IF_ERROR(InjectCalibrationNode(s));
|
||||
tensorflow::Status status;
|
||||
tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
|
||||
tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
|
||||
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
for (auto in_edge :
|
||||
params->subgraph_incoming_edges) { // loop over incoming edges and
|
||||
// attach them to calib node
|
||||
// tensorflow::Node* src_node = in_edge->src();
|
||||
auto src_output = in_edge->src_output();
|
||||
auto dst_node = in_edge->dst();
|
||||
auto dst_input = in_edge->dst_input();
|
||||
VLOG(1) << " update edge " << trt_node->name() << ":" << src_output
|
||||
<< " -> " << dst_node->name() << ":" << dst_input;
|
||||
TF_RETURN_IF_ERROR(
|
||||
params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input));
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
|
||||
TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params));
|
||||
tensorflow::NodeDef trt_node_def;
|
||||
|
||||
SubGraphParams s(params->graph, params->subgraph_node_ids,
|
||||
params->subgraph_inputs, params->subgraph_outputs,
|
||||
params->max_batch_size, params->max_workspace_size_bytes,
|
||||
params->graph_properties, params->output_edge_map,
|
||||
&trt_node_def, params->precision_mode);
|
||||
TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s));
|
||||
tensorflow::Status status;
|
||||
tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
|
||||
|
||||
// AddNode does not wire edges.
|
||||
// Re-map incoming edges to use the new TRT node instead of the orig subgraph
|
||||
std::map<std::pair<int, int>, int> subgraph_edge_to_input_map;
|
||||
for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) {
|
||||
subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i});
|
||||
}
|
||||
for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) {
|
||||
std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
|
||||
int new_src_output = subgraph_edge_to_input_map.at(old_src);
|
||||
params->graph.AddEdge(edge->src(), edge->src_output(), trt_node,
|
||||
new_src_output);
|
||||
params->graph.RemoveEdge(edge);
|
||||
}
|
||||
|
||||
VLOG(2) << "new wiring edges: " << trt_node->in_edges().size();
|
||||
for (const tensorflow::Edge* edge : trt_node->in_edges()) {
|
||||
VLOG(2) << edge->src()->name() << " port: " << edge->src_output();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
// Re-map outgoing edges to use the new TRT node instead of the orig subgraph
|
||||
std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
|
||||
for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
|
||||
subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
|
||||
for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) {
|
||||
subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i});
|
||||
}
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
|
||||
for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) {
|
||||
std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
|
||||
int new_src_output = subgraph_edge_to_output_map.at(old_src);
|
||||
TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(),
|
||||
edge->dst_input()));
|
||||
TF_RETURN_IF_ERROR(params->graph.UpdateEdge(
|
||||
trt_node, new_src_output, edge->dst(), edge->dst_input()));
|
||||
}
|
||||
// Remove the original subgraph
|
||||
for (int node_id : subgraph_node_ids) {
|
||||
tensorflow::Node* node = graph->FindNodeId(node_id);
|
||||
for (int node_id : params->subgraph_node_ids) {
|
||||
tensorflow::Node* node = params->graph.FindNodeId(node_id);
|
||||
// Don't remove the input placeholders
|
||||
if (node->type_string() == "Placeholder") {
|
||||
continue;
|
||||
}
|
||||
graph->RemoveNode(node);
|
||||
params->graph.RemoveNode(node);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
@ -194,12 +292,39 @@ tensorflow::Status BuildNodeMap(
|
||||
}
|
||||
|
||||
} // namespace
|
||||
tensorflow::Status ConvertCalibGraphToInferGraph(
|
||||
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) {
|
||||
VLOG(0) << "Starting Calib Conversion";
|
||||
tensorflow::Graph graph(tensorflow::OpRegistry::Global());
|
||||
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
|
||||
tensorflow::GraphConstructorOptions(), graph_def, &graph));
|
||||
// get calib nodes
|
||||
std::vector<tensorflow::Node*> calib_nodes;
|
||||
for (auto node : graph.op_nodes()) {
|
||||
if (node->type_string() == "TRTCalibOp") {
|
||||
VLOG(1) << "Found Calib Node";
|
||||
calib_nodes.push_back(node);
|
||||
}
|
||||
}
|
||||
VLOG(0) << "Num Calib nodes in graph= " << calib_nodes.size();
|
||||
if (calib_nodes.size() == 0)
|
||||
return tensorflow::errors::FailedPrecondition(
|
||||
"Graph doesn't contain any calibration nodes!."
|
||||
" Please generate calibration graph and run calibration first");
|
||||
for (auto n : calib_nodes) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n));
|
||||
}
|
||||
graph.ToGraphDef(infer_graph);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertGraphDefToTensorRT(
|
||||
const tensorflow::GraphDef& graph_def,
|
||||
const std::vector<string>& output_names, size_t max_batch_size,
|
||||
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) {
|
||||
// Optimization pass
|
||||
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
|
||||
int precision_mode = FP32MODE, int minimum_segment_size = 3) {
|
||||
// optimization pass
|
||||
tensorflow::grappler::GrapplerItem item;
|
||||
item.fetch = output_names;
|
||||
tensorflow::GraphDef gdef;
|
||||
@ -209,16 +334,23 @@ tensorflow::Status ConvertGraphDefToTensorRT(
|
||||
tensorflow::grappler::LayoutOptimizer optimizer;
|
||||
tensorflow::grappler::Cluster* cluster;
|
||||
|
||||
// Virtual cluster
|
||||
// virtual cluster
|
||||
tensorflow::DeviceProperties device_properties;
|
||||
|
||||
device_properties.set_type("GPU");
|
||||
device_properties.mutable_environment()->insert({"architecture", "6"});
|
||||
cluster =
|
||||
new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
|
||||
|
||||
// single machine
|
||||
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
|
||||
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
|
||||
VLOG(2) << "cpu_cores: " << num_cpu_cores;
|
||||
VLOG(2) << "gpus: " << num_gpus;
|
||||
|
||||
TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef));
|
||||
|
||||
// Constant folding
|
||||
// constant folding
|
||||
item.graph = gdef;
|
||||
tensorflow::grappler::ConstantFolding fold(nullptr);
|
||||
TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));
|
||||
@ -226,7 +358,6 @@ tensorflow::Status ConvertGraphDefToTensorRT(
|
||||
// AJ refactoring shape inference through grappler/GraphProperties.
|
||||
tensorflow::grappler::GraphProperties static_graph_properties(item);
|
||||
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false));
|
||||
|
||||
// Build full graph
|
||||
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
|
||||
gdef.library());
|
||||
@ -243,7 +374,7 @@ tensorflow::Status ConvertGraphDefToTensorRT(
|
||||
}
|
||||
|
||||
// TODO(sami): this should be passed as a knob!!!!
|
||||
segment_options.minimum_segment_size = 2;
|
||||
segment_options.minimum_segment_size = minimum_segment_size;
|
||||
tensorflow::tensorrt::segment::SegmentNodesVector segments;
|
||||
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
|
||||
gdef, IsTensorRTCandidate, segment_options, &segments));
|
||||
@ -252,14 +383,37 @@ tensorflow::Status ConvertGraphDefToTensorRT(
|
||||
}
|
||||
std::unordered_map<string, tensorflow::Node*> node_map;
|
||||
TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
|
||||
std::unordered_map<string, std::pair<int, string>> output_edge_map;
|
||||
int count = 0;
|
||||
float total_num_nodes_in_segments = 0.;
|
||||
for (auto s : segments) {
|
||||
total_num_nodes_in_segments += s.size();
|
||||
}
|
||||
for (const std::set<string>& subgraph_node_names : segments) {
|
||||
std::set<int> subgraph_node_ids;
|
||||
size_t max_mem_per_engine =
|
||||
max_workspace_size_bytes *
|
||||
((float)subgraph_node_names.size() / total_num_nodes_in_segments);
|
||||
std::stringstream oss;
|
||||
for (const string& node_name : subgraph_node_names) {
|
||||
oss << " " << node_name;
|
||||
subgraph_node_ids.insert(node_map.at(node_name)->id());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
|
||||
output_names, subgraph_node_ids, max_batch_size,
|
||||
max_workspace_size_bytes, static_graph_properties, &graph));
|
||||
VLOG(2) << "Subgraph nodes" << oss.str();
|
||||
ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size,
|
||||
max_mem_per_engine, static_graph_properties,
|
||||
&output_edge_map, precision_mode);
|
||||
if (precision_mode == INT8MODE) {
|
||||
TF_RETURN_IF_ERROR(GetCalibNode(&p));
|
||||
} else {
|
||||
tensorflow::Status status = ConvertSubGraphToTensorRT(&p);
|
||||
if (status != tensorflow::Status::OK()) {
|
||||
LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
|
||||
<< " due to: \n"
|
||||
<< status.ToString() << " SKIPPING......";
|
||||
}
|
||||
count++;
|
||||
}
|
||||
}
|
||||
graph.ToGraphDef(new_graph_def);
|
||||
return tensorflow::Status::OK();
|
||||
|
@ -28,14 +28,20 @@ namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace convert {
|
||||
|
||||
// This method converts an already generated calibration graph which was used in
|
||||
// calibration runs to an inference graph
|
||||
tensorflow::Status ConvertCalibGraphToInferGraph(
|
||||
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def);
|
||||
|
||||
// max_batch_size: maximum batch size which can be used for inference for
|
||||
// optimization targets inference run with max batch size.
|
||||
// max_workspace_size_bytes: The upper bound of memory allowence for
|
||||
// max_workspace_size_bytes: The upper bound of memory allowance for
|
||||
// engine building.
|
||||
tensorflow::Status ConvertGraphDefToTensorRT(
|
||||
const tensorflow::GraphDef& graph_def,
|
||||
const std::vector<string>& output_names, size_t max_batch_size,
|
||||
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def);
|
||||
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
|
||||
int precision_mode, int minimum_segment_size);
|
||||
|
||||
} // namespace convert
|
||||
} // namespace tensorrt
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -17,6 +17,8 @@ limitations under the License.
|
||||
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -32,16 +34,49 @@ namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace convert {
|
||||
|
||||
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
|
||||
const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
|
||||
const std::vector<std::pair<int, int>>&
|
||||
input_inds, // {node_id, output_idx}
|
||||
const std::vector<std::pair<int, int>>&
|
||||
output_inds, // {node_id, output_idx}
|
||||
size_t max_batch_size, size_t max_workspace_size_bytes,
|
||||
const tensorflow::grappler::GraphProperties& graph_prop,
|
||||
tensorflow::NodeDef* trt_node);
|
||||
const int FP32MODE = 0;
|
||||
const int FP16MODE = 1;
|
||||
const int INT8MODE = 2;
|
||||
|
||||
struct SubGraphParams {
|
||||
SubGraphParams(
|
||||
tensorflow::Graph& inp_graph,
|
||||
const std::set<int>& subgraph_node_id_numbers,
|
||||
const std::vector<std::pair<int, int>>& input_indices,
|
||||
const std::vector<std::pair<int, int>>& output_indices,
|
||||
size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
|
||||
const tensorflow::grappler::GraphProperties& current_graph_properties,
|
||||
std::unordered_map<string, std::pair<int, string>>* output_edges,
|
||||
tensorflow::NodeDef* constructed_trt_node,
|
||||
int engine_precision_mode = FP32MODE)
|
||||
: graph(inp_graph),
|
||||
subgraph_node_ids(subgraph_node_id_numbers),
|
||||
input_inds(input_indices),
|
||||
output_inds(output_indices),
|
||||
max_batch_size(max_supported_batch_size),
|
||||
max_workspace_size_bytes(max_consumed_workspace_size_bytes),
|
||||
graph_properties(current_graph_properties),
|
||||
output_edge_map(output_edges),
|
||||
trt_node(constructed_trt_node),
|
||||
precision_mode(engine_precision_mode) {}
|
||||
|
||||
tensorflow::Graph& graph;
|
||||
const std::set<int>& subgraph_node_ids;
|
||||
const std::vector<std::pair<int, int>>& input_inds; // {node_id, output_idx}
|
||||
const std::vector<std::pair<int, int>>& output_inds; // {node_id, output_idx}
|
||||
size_t max_batch_size;
|
||||
size_t max_workspace_size_bytes;
|
||||
const tensorflow::grappler::GraphProperties& graph_properties;
|
||||
std::unordered_map<string, std::pair<int, string>>* output_edge_map;
|
||||
tensorflow::NodeDef* trt_node;
|
||||
const int precision_mode;
|
||||
};
|
||||
|
||||
// TODO(sami): Replace references with const reference or pointers
|
||||
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params);
|
||||
tensorflow::Status InjectCalibrationNode(SubGraphParams& params);
|
||||
tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph,
|
||||
tensorflow::Node* c_node);
|
||||
} // namespace convert
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
@ -21,10 +21,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda_runtime_api.h"
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -113,7 +114,13 @@ void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) {
|
||||
ctx->set_output(i, t);
|
||||
}
|
||||
VLOG(2) << "Filled map for sending";
|
||||
calib_res->calibrator_->setBatch(input_data);
|
||||
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
|
||||
const cudaStream_t* stream = CHECK_NOTNULL(
|
||||
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
|
||||
->stream()
|
||||
->implementation()
|
||||
->CudaStreamMemberHack()));
|
||||
calib_res->calibrator_->setBatch(input_data, *stream);
|
||||
VLOG(2) << "Passed calibration data";
|
||||
// TODO(aaroey): make sure we wait for the completion of calibration on the
|
||||
// last batch in future PR.
|
||||
|
@ -24,8 +24,12 @@ limitations under the License.
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
static ::tensorflow::tensorrt::Logger logger;
|
||||
namespace gpu = ::perftools::gputools;
|
||||
using IRuntime = nvinfer1::IRuntime;
|
||||
using Dims = nvinfer1::Dims;
|
||||
|
||||
namespace tensorrt {
|
||||
|
||||
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
// read serialized_engine
|
||||
@ -40,10 +44,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
// TODO(samikama) runtime should be taken from a resourcemanager as well.
|
||||
// Only engine should be in the op and context and runtime should be taken
|
||||
// from resourcemanager
|
||||
nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
|
||||
// TODO(jie): cudaSetDevice make sure trt engine is allocated on the same
|
||||
// gpu where the input/output is also located.
|
||||
int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
|
||||
cudaSetDevice(gpu_id);
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
if (gpu_id != device) LOG(FATAL) << "set device failed!";
|
||||
|
||||
// TODO(samikama) runtime should be taken from a resourcemanager as well.
|
||||
// Only engine should be in the op and context and runtime should be taken
|
||||
// from resourcemanager
|
||||
|
||||
IRuntime* infer = nvinfer1::createInferRuntime(logger);
|
||||
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
|
||||
serialized_engine.c_str(), serialized_engine.size(), nullptr));
|
||||
|
||||
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
|
||||
// Runtime is safe to delete after engine creation
|
||||
infer->destroy();
|
||||
@ -55,7 +70,6 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
|
||||
|
||||
size_t binding_index;
|
||||
int num_batch = 0;
|
||||
bool valid = true;
|
||||
for (int i = 0; i < context->num_inputs(); i++) {
|
||||
// Grab the input tensor
|
||||
binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
|
||||
@ -64,8 +78,12 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
|
||||
const TensorShape& input_shape = input_tensor.shape();
|
||||
if (i == 0) {
|
||||
num_batch = input_shape.dim_size(0);
|
||||
if (num_batch > trt_engine_ptr_->getMaxBatchSize()) {
|
||||
LOG(FATAL) << "input tensor batch larger than max_batch_size: "
|
||||
<< trt_engine_ptr_->getMaxBatchSize();
|
||||
}
|
||||
} else if (num_batch != input_shape.dim_size(0)) {
|
||||
valid = false;
|
||||
LOG(FATAL) << "input data inconsistent batch size";
|
||||
break;
|
||||
}
|
||||
switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
|
||||
@ -81,9 +99,6 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Might want a different way to inform the user of batch size inconsistency
|
||||
if (!valid) LOG(WARNING) << "input data inconsistent batch size";
|
||||
|
||||
for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
|
||||
// This is bad that we have to reallocate output buffer every run.
|
||||
// Create an output tensor
|
||||
@ -126,9 +141,11 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
|
||||
->implementation()
|
||||
->CudaStreamMemberHack()));
|
||||
|
||||
// execution handled by TF since we are getting stream from TF.
|
||||
// it is safe for CPU pointer array (buffers) to go out of scope after enqueue
|
||||
trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr);
|
||||
// TODO(jie): trt enqueue does not return error
|
||||
auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0],
|
||||
*stream, nullptr);
|
||||
VLOG(2) << "enqueue returns: " << ret;
|
||||
// sync should be done by TF.
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user