Merge changes from github.

PiperOrigin-RevId: 189231636

commit ccd8079e57 (parent 61032e9ca7)
Changed files:

- README.md
- SECURITY.md
- configure.py
- tensorflow/compiler/xla
- tensorflow/contrib
  - BUILD, __init__.py
  - boosted_trees/lib/utils
  - cmake
  - data
  - distributions
  - eager/python
  - factorization
  - ffmpeg/default
  - gan
  - kafka
  - kfac/python/kernel_tests
  - labeled_tensor
  - layers
  - learn
  - lite
    - Makefile, arena_planner.h, build_rpi_lib.sh, builtin_ops.h, error_reporter.h
    - g3doc
    - interpreter.h, interpreter_test.cc
    - kernels: conv.cc, depthwise_conv.cc, fully_connected.cc, kernel_util.h, lsh_projection.cc, lstm.cc, reshape.cc, reshape_test.cc, test_util.cc, unidirectional_sequence_lstm.cc
    - memory_planner.h, model.h, nnapi
    - rpi_makefile.inc, schema/builtin_ops_header
    - simple_memory_arena.cc, simple_memory_arena.h
  - lookup
  - makefile
  - mpi
  - predictor
  - py2tf
  - quantize/python: fold_batch_norms.py, quant_ops.py, quantize.py, quantize_graph.py, quantize_parameterized_test.py, quantize_test.py
  - remote_fused_graph/pylib
  - rnn/python/ops
  - saved_model
  - seq2seq/python/ops
  - session_bundle
  - slim/python/slim/data
  - tensor_forest
  - tensorboard
  - tensorrt
  - timeseries
README.md
@@ -22,6 +22,10 @@ organization for the purposes of conducting machine learning and deep neural
 networks research. The system is general enough to be applicable in a wide
 variety of other domains, as well.

+Keep up to date with release announcements and security updates by
+subscribing to
+[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
+
 ## Installation
 *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.*
SECURITY.md
@@ -6,7 +6,7 @@ report vulnerabilities in TensorFlow.

 ## TensorFlow models are programs

-TensorFlow's runtime system interprets and executes programs. What machine 
+TensorFlow's runtime system interprets and executes programs. What machine
 learning practitioners term
 [**models**](https://developers.google.com/machine-learning/glossary/#model) are
 expressed as programs that TensorFlow executes. TensorFlow programs are encoded
@@ -28,12 +28,12 @@ data you supply to TensorFlow to train a model, or to use a model to run
 inference on the data.

 **TensorFlow models are programs, and need to be treated as such from a security
-perspective.** 
+perspective.**

 ## Running untrusted models

 As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
-[nsjail](https://github.com/google/nsjail)). 
+[nsjail](https://github.com/google/nsjail)).

 There are several ways in which a model could become untrusted. Obviously, if an
 untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
@@ -109,11 +109,11 @@ graphs known to the `ModelServer`. This means that an attacker may run
 graphs using untrusted inputs as described above, but they would not be able to
 execute arbitrary graphs. It is possible to safely expose a `ModelServer`
 directly to an untrusted network, **but only if the graphs it is configured to
-use have been carefully audited to be safe**. 
+use have been carefully audited to be safe**.

 Similar to best practices for other servers, we recommend running any
 `ModelServer` with appropriate privileges (i.e., using a separate user with
-reduced permisisons). In the spirit of defense in depth, we recommend
+reduced permissions). In the spirit of defense in depth, we recommend
 authenticating requests to any TensorFlow server connected to an untrusted
 network, as well as sandboxing the server to minimize the adverse effects of
 any breach.
@@ -133,7 +133,7 @@ which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models
 can perform arbitrary computations means that they may read and write files,
 communicate via the network, produce deadlocks and infinite loops, or run out
 of memory. It is only when these behaviors are outside the specifications of the
-operations involved that such behavior is a vulnerability. 
+operations involved that such behavior is a vulnerability.

 A `FileWriter` writing a file is not unexpected behavior and therefore is not a
 vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
@@ -168,7 +168,7 @@ below).

 Please use a descriptive subject line for your report email. After the initial
 reply to your report, the security team will endeavor to keep you informed of
-the progress being made towards a fix and announcement. 
+the progress being made towards a fix and announcement.

 If you believe that an existing (public) issue is security-related, please send
 an email to `security@tensorflow.org`. The email should include the issue ID and
configure.py
@@ -1048,7 +1048,10 @@ def set_tf_tensorrt_install_path(environ_cp):

   for lib_file in possible_files:
     if is_compatible(lib_file, cuda_ver, cudnn_ver):
-      ver_str = nvinfer_pattern.search(lib_file).group(1)
+      matches = nvinfer_pattern.search(lib_file)
+      if len(matches.groups()) == 0:
+        continue
+      ver_str = matches.group(1)
       ver = convert_version_to_int(ver_str) if len(ver_str) else 0
       if ver > highest_ver[0]:
         highest_ver = [ver, ver_str, lib_file]
@@ -1377,7 +1380,7 @@ def main():
   # environment variables.
   environ_cp = dict(os.environ)

-  check_bazel_version('0.5.4')
+  check_bazel_version('0.10.0')

   reset_tf_configure_bazelrc(args.workspace)
   cleanup_makefile()
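The configure.py change above stops calling `.group(1)` directly on a regex search and instead checks the match first. A minimal standalone sketch of the same guard; the filenames and the exact pattern here are illustrative assumptions, not the ones used in configure.py:

```python
import re

# Hypothetical library filenames; configure.py builds this list by
# scanning the user-supplied TensorRT install path.
possible_files = ["libnvinfer.so.4.1.2", "libnvinfer.so"]
nvinfer_pattern = re.compile(r"libnvinfer\.so\.?(.*)")

for lib_file in possible_files:
    matches = nvinfer_pattern.search(lib_file)
    # Guard before using .group(1): skip names that carry no version
    # suffix instead of raising on a failed or empty match.
    if matches is None or not matches.group(1):
        continue
    print(lib_file, "->", matches.group(1))  # libnvinfer.so.4.1.2 -> 4.1.2
```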
@@ -38,14 +38,7 @@ namespace xla {

 GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
                                                size_t pointer_size)
-    : platform_id_(platform_id), pointer_size_(pointer_size) {
-  // We currently only support kHostPlatformId for CPU, kCudaPlatformId for
-  // GPU and kInterpreterPlatformId for Interpreter. Before supporting other
-  // platforms, we need to test this transfer manager on them.
-  CHECK(platform_id_ == se::host::kHostPlatformId ||
-        platform_id_ == se::interpreter::kInterpreterPlatformId ||
-        platform_id_ == se::cuda::kCudaPlatformId);
-}
+    : platform_id_(platform_id), pointer_size_(pointer_size) {}

 se::Platform::Id GenericTransferManager::PlatformId() const {
   return platform_id_;
@@ -723,7 +723,7 @@ INSTANTIATE_TEST_CASE_P(
 );
 #endif

-TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
+XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
   ComputationBuilder builder(client_, TestName());
   Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
   Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
tensorflow/contrib/BUILD
@@ -8,6 +8,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
 load("//third_party/mpi:mpi.bzl", "if_mpi")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
+load("//tensorflow:tensorflow.bzl", "if_not_windows")

 py_library(
     name = "contrib_py",
@@ -51,7 +52,6 @@ py_library(
         "//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
         "//tensorflow/contrib/input_pipeline:input_pipeline_py",
         "//tensorflow/contrib/integrate:integrate_py",
-        "//tensorflow/contrib/kafka",
         "//tensorflow/contrib/keras",
         "//tensorflow/contrib/kernel_methods",
         "//tensorflow/contrib/kfac",
@@ -63,7 +63,6 @@ py_library(
         "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
         "//tensorflow/contrib/linear_optimizer:sdca_ops_py",
-        "//tensorflow/contrib/lite/python:lite",
         "//tensorflow/contrib/lookup:lookup_py",
         "//tensorflow/contrib/losses:losses_py",
         "//tensorflow/contrib/losses:metric_learning_py",
@@ -110,6 +109,10 @@ py_library(
         "//tensorflow/python:util",
     ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([
         "//tensorflow/contrib/tensorrt:init_py",
+    ]) + if_not_windows([
+        "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",  # unix dependency, need to fix code
+        "//tensorflow/contrib/lite/python:lite",  # unix dependency, need to fix code
+        "//tensorflow/contrib/kafka",  # has some linking issue on openssl.
     ]),
 )

@@ -121,6 +124,7 @@ cc_library(
         "//tensorflow/contrib/coder:all_kernels",
         "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels",
         "//tensorflow/contrib/data/kernels:dataset_kernels",
+        "//tensorflow/contrib/kafka:dataset_kernels",
         "//tensorflow/contrib/factorization/kernels:all_kernels",
         "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
         "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
@@ -147,7 +151,7 @@ cc_library(
         "//tensorflow/contrib/factorization:all_ops",
         "//tensorflow/contrib/framework:all_ops",
         "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
-        "//tensorflow/contrib/kafka:kafka_ops_op_lib",
+        "//tensorflow/contrib/kafka:dataset_ops_op_lib",
         "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
         "//tensorflow/contrib/nccl:nccl_ops_op_lib",
         "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib",
tensorflow/contrib/__init__.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
+
 # Add projects here, they will show up under tf.contrib.
 from tensorflow.contrib import batching
 from tensorflow.contrib import bayesflow
@@ -83,7 +85,8 @@ from tensorflow.contrib import tpu
 from tensorflow.contrib import training
 from tensorflow.contrib import util
 from tensorflow.contrib.eager.python import tfe as eager
-from tensorflow.contrib.lite.python import lite
+if os.name != 'nt':
+  from tensorflow.contrib.lite.python import lite
 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
 from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph
 from tensorflow.contrib.specs import python as specs
@@ -92,6 +95,7 @@ from tensorflow.contrib.summary import summary
 from tensorflow.python.util.lazy_loader import LazyLoader
 ffmpeg = LazyLoader("ffmpeg", globals(),
                     "tensorflow.contrib.ffmpeg")
+del os
 del LazyLoader

 del absolute_import
@@ -48,9 +48,9 @@ class BatchFeatures {
   Status GetFeatureColumnSizes(int64* const num_dense_float_features,
                                int64* const num_sparse_float_features,
                                int64* const num_sparse_int_features) const {
-    QCHECK_NE(num_dense_float_features, nullptr);
-    QCHECK_NE(num_sparse_float_features, nullptr);
-    QCHECK_NE(num_sparse_int_features, nullptr);
+    QCHECK_NE(num_dense_float_features, (int64*) nullptr);
+    QCHECK_NE(num_sparse_float_features, (int64*) nullptr);
+    QCHECK_NE(num_sparse_int_features, (int64*) nullptr);
     *num_dense_float_features = dense_float_feature_columns_.size();
     *num_sparse_float_features = sparse_float_feature_columns_.size();
     *num_sparse_int_features = sparse_int_feature_columns_.size();
tensorflow/contrib/cmake/README.md
@@ -26,7 +26,7 @@ The CMake files in this directory can build the core TensorFlow runtime, an
 example C++ binary, and a PIP package containing the runtime and Python
 bindings.

-### Pre-requisites
+### Prerequisites

 * CMake version 3.5 or later.

@@ -34,14 +34,16 @@ bindings.

 * [SWIG](http://www.swig.org/download.html)

-* Additional pre-requisites for Microsoft Windows:
+* Additional prerequisites for Microsoft Windows:
    - Visual Studio 2015
    - Python 3.5
    - NumPy 1.11.0 or later

-* Additional pre-requisites for Linux:
+* Additional prerequisites for Linux:
    - Python 2.7 or later
    - [Docker](https://www.docker.com/) (for automated testing)

+* Python dependencies:
+   - wheel
+   - NumPy 1.11.0 or later
+
 ### Known-good configurations
@@ -102,7 +104,7 @@ ops or APIs.
 Step-by-step Windows build
 ==========================

-1. Install the pre-requisites detailed above, and set up your environment.
+1. Install the prerequisites detailed above, and set up your environment.

    * The following commands assume that you are using the Windows Command
      Prompt (`cmd.exe`). You will need to set up your environment to use the
tensorflow/contrib/cmake/external/grpc.cmake (vendored)
@@ -35,6 +35,7 @@ else()
   set(grpc_STATIC_LIBRARIES
       ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a
       ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a
+      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a
       ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a)
 endif()
@@ -16,7 +16,7 @@ include (ExternalProject)

 set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src)
 set(PROTOBUF_URL https://github.com/google/protobuf.git)
-set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a)
+set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9)

 if(WIN32)
   if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
@@ -476,6 +476,10 @@ if (tensorflow_BUILD_CC_TESTS)
     "${tensorflow_source_dir}/tensorflow/core/profiler/internal/advisor/*_test.cc"
   )

+  list(REMOVE_ITEM tf_test_src_simple
+    ${tf_core_profiler_test_srcs}
+  )
+
   set(tf_test_lib tf_test_lib)
   add_library(${tf_test_lib} STATIC ${tf_src_testlib})
tensorflow/contrib/data/__init__.py
@@ -40,6 +40,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
 @@rejection_resample
 @@scan
 @@shuffle_and_repeat
+@@sliding_window_batch
 @@sloppy_interleave
 @@unbatch

@@ -72,6 +73,9 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset
 from tensorflow.contrib.data.python.ops.resampling import rejection_resample
 from tensorflow.contrib.data.python.ops.scan_ops import scan
 from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
+from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
 from tensorflow.python.data.ops.iterator_ops import Iterator
 from tensorflow.python.ops.parsing_ops import parse_single_example_v2 as parse_single_example
 # pylint: enable=unused-import

 from tensorflow.python.util.all_util import remove_undocumented
@@ -498,6 +498,23 @@ py_test(
     ],
 )

+tf_py_test(
+    name = "slide_dataset_op_test",
+    size = "small",
+    srcs = ["slide_dataset_op_test.py"],
+    additional_deps = [
+        "//tensorflow/contrib/data/python/ops:dataset_ops",
+        "//tensorflow/contrib/data/python/ops:transformation_ops",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:sparse_tensor",
+        "//third_party/py/numpy",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py (new file)
@@ -0,0 +1,242 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.data.python.ops import sliding
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class SlideDatasetTest(test.TestCase):

  def testSlideDataset(self):
    """Test a dataset that maps a TF function across its input elements."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    count = array_ops.placeholder(dtypes.int64, shape=[])
    window_size = array_ops.placeholder(dtypes.int64, shape=[])
    stride = array_ops.placeholder(dtypes.int64, shape=[])

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(count) -> _SlideDataset(window_size, stride).
    iterator = (dataset_ops.Dataset.from_tensor_slices(components)
                .map(_map_fn)
                .repeat(count)
                .apply(sliding.sliding_window_batch(window_size, stride))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                     [t.shape.as_list() for t in get_next])

    with self.test_session() as sess:
      # Slide over a finite input, where the window_size divides the
      # total number of elements.
      sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7})
      # Same formula as for a convolution layer's output size.
      num_batches = (20 * 7 - 14) // 7 + 1
      for i in range(num_batches):
        result = sess.run(get_next)
        for component, result_component in zip(components, result):
          for j in range(14):
            self.assertAllEqual(component[(i*7 + j) % 7]**2,
                                result_component[j])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Slide over a finite input, where the window_size does not
      # divide the total number of elements.
      sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9})

      num_batches = (20 * 7 - 17) // 9 + 1
      for i in range(num_batches):
        result = sess.run(get_next)
        for component, result_component in zip(components, result):
          for j in range(17):
            self.assertAllEqual(component[(i*9 + j) % 7]**2,
                                result_component[j])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Slide over a finite input, which is less than window_size,
      # should fail straight away.
      sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Slide over an empty input should fail straight away.
      sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Empty window_size should be an initialization time error.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0})

      # Invalid stride should be an initialization time error.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0})
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3})
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5})

  def assertSparseValuesEqual(self, a, b):
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def testSlideSparse(self):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    iterator = dataset_ops.Dataset.range(10).map(_sparse).apply(
        sliding.sliding_window_batch(5, 3)).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      num_batches = (10 - 5) // 3 + 1
      for i in range(num_batches):
        actual = sess.run(get_next)
        expected = sparse_tensor.SparseTensorValue(
            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
            values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
            dense_shape=[5, 1])
        self.assertTrue(sparse_tensor.is_sparse(actual))
        self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSlideSparseWithDifferentDenseShapes(self):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=array_ops.expand_dims(
              math_ops.range(i, dtype=dtypes.int64), 1),
          values=array_ops.fill([math_ops.to_int32(i)], i),
          dense_shape=[i])

    iterator = dataset_ops.Dataset.range(10).map(_sparse).apply(
        sliding.sliding_window_batch(5, 3)).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      num_batches = (10 - 5) // 3 + 1
      for i in range(num_batches):
        actual = sess.run(get_next)
        expected_indices = []
        expected_values = []
        for j in range(5):
          for k in range(i * 3 + j):
            expected_indices.append([j, k])
            expected_values.append(i * 3 + j)
        expected = sparse_tensor.SparseTensorValue(
            indices=expected_indices,
            values=expected_values,
            dense_shape=[5, i * 3 + 5 - 1])
        self.assertTrue(sparse_tensor.is_sparse(actual))
        self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testNestedSlideSparse(self):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    iterator = (dataset_ops.Dataset.range(10)
                .map(_sparse)
                .apply(sliding.sliding_window_batch(4, 2))
                .apply(sliding.sliding_window_batch(3, 1))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      # Slide: 1st batch.
      actual = sess.run(get_next)
      expected = sparse_tensor.SparseTensorValue(
          indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
                   [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0],
                   [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]],
          values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
          dense_shape=[3, 4, 1])
      self.assertTrue(sparse_tensor.is_sparse(actual))
      self.assertSparseValuesEqual(actual, expected)
      # Slide: 2nd batch.
      actual = sess.run(get_next)
      expected = sparse_tensor.SparseTensorValue(
          indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
                   [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0],
                   [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]],
          values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
          dense_shape=[3, 4, 1])
      self.assertTrue(sparse_tensor.is_sparse(actual))
      self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSlideShapeError(self):

    def generator():
      yield [1.0, 2.0, 3.0]
      yield [4.0, 5.0, 6.0]
      yield [7.0, 8.0, 9.0, 10.0]

    iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32,
                                                   output_shapes=[None])
                .apply(sliding.sliding_window_batch(3, 1))
                .make_initializable_iterator())
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(iterator.initializer)
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r"Cannot batch tensors with different shapes in component 0. "
          r"First element had shape \[3\] and element 2 had shape \[4\]."):
        sess.run(next_element)


if __name__ == "__main__":
  test.main()
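The `num_batches` arithmetic in the test above is the standard sliding-window count, the same as a 1-D convolution's output length under VALID padding. A quick standalone check (plain Python, no TensorFlow needed):

```python
def num_windows(num_elements, window_size, stride):
    # Number of full windows a sliding window of `window_size`,
    # advancing by `stride`, fits into `num_elements` elements.
    return (num_elements - window_size) // stride + 1

# The three finite-input cases exercised by the test:
assert num_windows(20 * 7, 14, 7) == 19
assert num_windows(20 * 7, 17, 9) == 14
assert num_windows(10, 5, 3) == 2
```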
tensorflow/contrib/data/python/ops/BUILD
@@ -106,6 +106,7 @@ py_library(
         "interleave_ops.py",
         "resampling.py",
         "scan_ops.py",
+        "sliding.py",
         "stats_ops.py",
         "threadpool.py",
         "unique.py",
tensorflow/contrib/data/python/ops/sliding.py (new file, 102 lines)
@@ -0,0 +1,102 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sliding dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops


class _SlideDataset(dataset_ops.Dataset):
  """A `Dataset` that passes a sliding window over its input."""

  def __init__(self, input_dataset, window_size, stride=1):
    """See `sliding_window_batch` for details."""
    super(_SlideDataset, self).__init__()
    self._input_dataset = input_dataset
    self._window_size = ops.convert_to_tensor(
        window_size, dtype=dtypes.int64, name="window_size")
    self._stride = ops.convert_to_tensor(
        stride, dtype=dtypes.int64, name="stride")

  def _as_variant_tensor(self):
    return gen_dataset_ops.slide_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        window_size=self._window_size,
        stride=self._stride,
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    input_shapes = self._input_dataset.output_shapes
    return nest.pack_sequence_as(input_shapes, [
        tensor_shape.vector(None).concatenate(s)
        for s in nest.flatten(self._input_dataset.output_shapes)
    ])

  @property
  def output_types(self):
    return self._input_dataset.output_types


def sliding_window_batch(window_size, stride=1):
  """A sliding window with size of `window_size` and step of `stride`.

  This transformation passes a sliding window over this dataset. The
  window size is `window_size` and step size is `stride`. If the
  remaining elements cannot fill up the sliding window, this
  transformation will drop the final smaller window. For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { [1], [2], [3], [4], [5], [6] }

  a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) ==
  {
      [[1], [2], [3]],
      [[3], [4], [5]],
  }
  ```

  Args:
    window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      elements in the sliding window.
    stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      steps moving the sliding window forward for one iteration. The default
      is `1`. It must be in `[1, window_size)`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    return _SlideDataset(dataset, window_size, stride)

  return _apply_fn
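To make the new transformation concrete, here is a minimal usage sketch. It assumes a TF 1.x-era build where this commit's `tf.contrib.data.sliding_window_batch` is available; the input values are arbitrary:

```python
import numpy as np
import tensorflow as tf

# Six scalar elements; window_size=3 and stride=2 yield the windows
# [1 2 3] and [3 4 5]; the final partial window is dropped.
dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3, 4, 5, 6]))
dataset = dataset.apply(
    tf.contrib.data.sliding_window_batch(window_size=3, stride=2))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  print(sess.run(next_element))  # [1 2 3]
  print(sess.run(next_element))  # [3 4 5]
```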
@@ -454,6 +454,7 @@ cuda_py_test(
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
 )

 cuda_py_test(
@@ -1143,6 +1144,7 @@ cuda_py_test(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
 )

 cuda_py_test(
@@ -266,7 +266,10 @@ cuda_py_test(
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras",
     ],
-    tags = ["notsan"],
+    tags = [
+        "no_windows",  # TODO: needs investigation on Windows
+        "notsan",
+    ],
 )

 filegroup(
@@ -22,6 +22,7 @@ cuda_py_test(
         ":linear_regression",
         "//tensorflow:tensorflow_py",
     ],
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
 )

 cuda_py_test(
@@ -224,7 +224,10 @@ py_test(
     srcs = ["python/ops/kmeans_test.py"],
     shard_count = 4,
     srcs_version = "PY2AND3",
-    tags = ["notsan"],  # b/67512932
+    tags = [
+        "nomac",  # b/73741358
+        "notsan",  # b/67512932
+    ],
     deps = [
         ":factorization_py",
         ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
@@ -256,6 +256,9 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height,
     if (p != std::string::npos) {
       string rgb24 = line.substr(p + 9, line.find(" ", p + 9));
       rgb24 = rgb24.substr(0, rgb24.find(","));
+      // Strip anything after " ", in case the format is
+      // `640x360 [SAR 1:1 DAR 16:9]`
+      rgb24 = rgb24.substr(0, rgb24.find(" "));
       string rgb24_width = rgb24.substr(0, rgb24.find("x"));
       string rgb24_height = rgb24.substr(rgb24_width.length() + 1);
       if (strings::safe_strtou32(rgb24_width, &width_value) &&
@@ -270,8 +273,10 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height,
       // We only look for the first stream mapping to have the number of the
       // frames.
       // Once processed we will not further process stream mapping section.
-      if (line.find("frame= ") == 0) {
-        string number = line.substr(8, line.find(" ", 8));
+      if (line.find("frame=") == 0) {
+        // The format might be `frame=  166 ` or `frame=12488 `
+        string number = line.substr(6);
+        number = number.substr(number.find_first_not_of(" "));
+        number = number.substr(0, number.find(" "));
         if (strings::safe_strtou32(number, &frames_value)) {
           in_mapping = false;
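The new parsing tolerates both padded and unpadded frame counters in ffmpeg's output. A quick Python rendering of the same string handling; the sample lines are made up:

```python
def parse_frame_count(line):
    # Mirrors the C++ above: accept `frame=  166 ...` as well as
    # `frame=12488 ...` by stripping leading spaces after "frame=".
    if not line.startswith("frame="):
        return None
    number = line[6:].lstrip(" ")
    number = number.split(" ", 1)[0]
    return int(number) if number.isdigit() else None

assert parse_frame_count("frame=  166 fps=25 q=-1.0") == 166
assert parse_frame_count("frame=12488 fps=25 q=-1.0") == 12488
```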
@@ -354,6 +354,7 @@ py_test(
     name = "classifier_metrics_test",
     srcs = ["python/eval/python/classifier_metrics_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":classifier_metrics",
         "//tensorflow/core:protos_all_py",
tensorflow/contrib/kafka/BUILD
@@ -1,66 +1,93 @@
-package(
-    default_visibility = ["//visibility:private"],
-)
+package(default_visibility = ["//tensorflow:internal"])

 licenses(["notice"])  # Apache 2.0

 exports_files(["LICENSE"])

-load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-
-tf_kernel_library(
-    name = "kafka_kernels",
-    srcs = ["kernels/kafka_dataset_ops.cc"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core/kernels:bounds_check_lib",
-        "//tensorflow/core/kernels:dataset",
-        "//third_party/eigen3",
-        "@kafka",
-    ],
-)
-
-tf_gen_op_libs(
-    op_lib_names = ["kafka_ops"],
-    deps = [
-        "//tensorflow/core:lib",
-    ],
-)
-
-tf_gen_op_wrapper_py(
-    name = "gen_kafka_ops",
-    out = "python/ops/gen_kafka_ops.py",
-    require_shape_functions = True,
-    deps = [":kafka_ops_op_lib"],
-)
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_gen_op_wrapper_py",
+    "tf_kernel_library",
+    "tf_custom_op_library",
+    "tf_custom_op_py_library",
+    "tf_gen_op_libs",
+    "tf_py_test",
+)

 py_library(
     name = "kafka",
     srcs = ["__init__.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":dataset_ops",
     ],
 )

+tf_custom_op_library(
+    name = "_dataset_ops.so",
+    srcs = ["ops/dataset_ops.cc"],
+    deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+    op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+    name = "dataset_kernels",
+    srcs = ["kernels/kafka_dataset_ops.cc"],
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//third_party/eigen3",
+        "@kafka",
+        "@protobuf_archive//:protobuf_headers",
+    ],
+    alwayslink = 1,
+)
+
 py_library(
     name = "dataset_ops",
     srcs = [
         "__init__.py",
         "python/ops/kafka_dataset_ops.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":gen_kafka_ops",
-        "//tensorflow/contrib/util:util_py",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:platform",
-        "//tensorflow/python:state_ops",
-        "//tensorflow/python:training",
+        ":kafka_op_loader",
+        "//tensorflow/python:dataset_ops_gen",
+        "//tensorflow/python:util",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "//tensorflow/python/data/ops:readers",
+        "//tensorflow/python/data/util:nest",
     ],
 )

+tf_gen_op_wrapper_py(
+    name = "gen_dataset_ops",
+    out = "python/ops/gen_dataset_ops.py",
+    deps = ["//tensorflow/contrib/kafka:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+    name = "dataset_ops_kernels",
+    deps = [
+        ":dataset_kernels",
+        "//tensorflow/core:framework",
+    ],
+    alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+    name = "kafka_op_loader",
+    srcs = ["python/ops/kafka_op_loader.py"],
+    dso = ["//tensorflow/contrib/kafka:_dataset_ops.so"],
+    kernels = [
+        ":dataset_ops_kernels",
+        "//tensorflow/contrib/kafka:dataset_ops_op_lib",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":gen_dataset_ops",
+        "//tensorflow/contrib/util:util_py",
+        "//tensorflow/python:platform",
+    ],
+)
+
@@ -88,6 +115,7 @@ tf_py_test(
     ],
     tags = [
         "manual",
+        "no_windows",
         "notap",
     ],
 )
@@ -95,7 +123,9 @@ tf_py_test(
 filegroup(
     name = "all_files",
     srcs = glob(
-        ["**/*"],
+        include = [
+            "**/*",
+        ],
         exclude = [
             "**/METADATA",
             "**/OWNERS",
@@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/core/kernels/dataset.h"
-
-#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/dataset.h"

 #include "src-cpp/rdkafkacpp.h"
tensorflow/contrib/kafka/ops/dataset_ops.cc (new file, 44 lines)
@@ -0,0 +1,44 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

REGISTER_OP("KafkaDataset")
    .Input("topics: string")
    .Input("servers: string")
    .Input("group: string")
    .Input("eof: bool")
    .Input("timeout: int64")
    .Output("handle: variant")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
Creates a dataset that emits the messages of one or more Kafka topics.

topics: A `tf.string` tensor containing one or more subscriptions,
  in the format of [topic:partition:offset:length],
  by default length is -1 for unlimited.
servers: A list of bootstrap servers.
group: The consumer group id.
eof: If True, the kafka reader will stop on EOF.
timeout: The timeout value for the Kafka Consumer to wait
  (in millisecond).
)doc");

}  // namespace tensorflow
tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
@@ -17,8 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.contrib.kafka.python.ops import gen_kafka_ops
-from tensorflow.python.data.ops.readers import Dataset
+from tensorflow.contrib.kafka.python.ops import kafka_op_loader  # pylint: disable=unused-import
+from tensorflow.contrib.kafka.python.ops import gen_dataset_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -58,8 +59,8 @@ class KafkaDataset(Dataset):
         timeout, dtype=dtypes.int64, name="timeout")

   def _as_variant_tensor(self):
-    return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group,
-                                       self._eof, self._timeout)
+    return gen_dataset_ops.kafka_dataset(self._topics, self._servers,
+                                         self._group, self._eof, self._timeout)

   @property
   def output_classes(self):
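For reference, a sketch of how the reworked `KafkaDataset` might be consumed after this change. The broker address, topic subscription, and argument values below are illustrative assumptions, not values from the commit:

```python
import tensorflow as tf
from tensorflow.contrib.kafka.python.ops.kafka_dataset_ops import KafkaDataset

# Subscription format per the op doc above: [topic:partition:offset:length],
# with length -1 meaning unlimited.
dataset = KafkaDataset(
    topics=["test:0:0:-1"],       # hypothetical topic/partition
    servers="localhost:9092",     # hypothetical bootstrap server
    group="test-consumer-group",
    eof=True,                     # stop when the reader hits EOF
    timeout=1000)                 # consumer wait timeout, in milliseconds

iterator = dataset.make_one_shot_iterator()
next_message = iterator.get_next()

with tf.Session() as sess:
  try:
    while True:
      print(sess.run(next_message))
  except tf.errors.OutOfRangeError:
    pass
```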
tensorflow/contrib/kafka/python/ops/kafka_op_loader.py (new file, 24 lines)
@@ -0,0 +1,24 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python helper for loading kafka ops and kernels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader

_dataset_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
@@ -114,6 +114,7 @@ py_test(
     name = "utils_test",
     srcs = ["utils_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         "//tensorflow/contrib/kfac/python/ops:utils",
         "//tensorflow/contrib/tpu",
@@ -70,6 +70,7 @@ py_test(
         "python/ops/core_test.py",
     ],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":_typecheck",
         ":core",
@@ -188,6 +188,7 @@ py_test(
     size = "small",
     srcs = ["python/layers/normalization_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":layers_py",
         "//tensorflow/contrib/framework:framework_py",
@@ -353,6 +354,7 @@ py_test(
     size = "small",
     srcs = ["python/ops/sparse_ops_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":layers_py",
         "//tensorflow/python:array_ops",
@@ -470,7 +470,7 @@ def embedding_lookup_unique(params, ids, name=None):
     ids = ops.convert_to_tensor(ids)
     shape = array_ops.shape(ids)
     ids_flat = array_ops.reshape(
-        ids, math_ops.reduce_prod(shape, keep_dims=True))
+        ids, math_ops.reduce_prod(shape, keepdims=True))
     unique_ids, idx = array_ops.unique(ids_flat)
     unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids)
     embeds_flat = array_ops.gather(unique_embeddings, idx)
@@ -5,6 +5,8 @@ licenses(["notice"])  # Apache 2.0

 exports_files(["LICENSE"])

+load("//tensorflow:tensorflow.bzl", "py_test")
+
 package(default_visibility = [
     "//engedu/ml/tf_from_scratch:__pkg__",
     "//tensorflow:internal",
@@ -115,6 +117,7 @@ py_test(
     size = "small",
     srcs = ["python/learn/learn_io/data_feeder_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":learn",
         "//tensorflow/python:client_testlib",
@@ -170,6 +173,7 @@ tf_py_test(
         "//tensorflow/python:variables",
         "//tensorflow/python/estimator",
     ],
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
 )

 py_test(
@@ -188,6 +192,7 @@ py_test(
     size = "small",
     srcs = ["python/learn/graph_actions_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":learn",
         "//tensorflow/contrib/framework:framework_py",
@@ -426,7 +431,10 @@ py_test(
     size = "medium",
     srcs = ["python/learn/estimators/kmeans_test.py"],
     srcs_version = "PY2AND3",
-    tags = ["noasan"],
+    tags = [
+        "noasan",  # b/73741358
+        "nomac",
+    ],
     deps = [
         ":learn",
         "//tensorflow/python:array_ops",
@@ -585,6 +593,7 @@ py_test(
     size = "small",
     srcs = ["python/learn/learn_io/io_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":learn",
         "//tensorflow/contrib/learn/python/learn/datasets",
@@ -814,6 +823,7 @@ py_test(
     size = "small",
     srcs = ["python/learn/utils/saved_model_export_utils_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":learn",
         "//tensorflow/contrib/layers:layers_py",
@@ -61,7 +61,7 @@ def embedding_lookup(params, ids, name='embedding_lookup'):
   ids = ops.convert_to_tensor(ids)
   shape = array_ops_.shape(ids)
   ids_flat = array_ops_.reshape(
-      ids, math_ops.reduce_prod(shape, keep_dims=True))
+      ids, math_ops.reduce_prod(shape, keepdims=True))
   embeds_flat = nn.embedding_lookup(params, ids_flat, name)
   embed_shape = array_ops_.concat([shape, [-1]], 0)
   embeds = array_ops_.reshape(embeds_flat, embed_shape)
@@ -27,10 +27,10 @@ LIBDIR := $(MAKEFILE_DIR)/gen/lib/
 GENDIR := $(MAKEFILE_DIR)/gen/obj/

 # Settings for the host compiler.
-CXX := $(CC_PREFIX) gcc
+CXX := $(CC_PREFIX)gcc
 CXXFLAGS := --std=c++11 -O3 -DNDEBUG
-CC := $(CC_PREFIX) gcc
-CFLAGS :=
+CC := $(CC_PREFIX)gcc
+CFLAGS := -O3 -DNDEBUG
 LDOPTS :=
 LDOPTS += -L/usr/local/lib
 ARFLAGS := -r
@@ -57,10 +57,11 @@ LIBS := \

 # If we're on Linux, also link in the dl library.
 ifeq ($(HOST_OS),LINUX)
-	LIBS += -ldl -lpthread
+	LIBS += -ldl
 endif

 include $(MAKEFILE_DIR)/ios_makefile.inc
+include $(MAKEFILE_DIR)/rpi_makefile.inc

 # This library is the main target for this makefile. It will contain a minimal
 # runtime that can be linked in to other programs.
@@ -33,7 +33,7 @@ class AllocationInfo;
 // each tensor needs to be allocated and deallocated, and preallocates all the
 // necessary memory (the PlanAllocations phase). It then assigns portions of
 // this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may
-// share some of the bufer if a tensor B is to be allocated after another tensor
+// share some of the buffer if a tensor B is to be allocated after another tensor
 // A has been deallocated.
 //
 // If dynamic tensors are used the planning steps can be repeated during model
tensorflow/contrib/lite/build_rpi_lib.sh (new executable file, 22 lines)
@@ -0,0 +1,22 @@
#!/bin/bash -x
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/../../.."

CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/Makefile TARGET=RPI TARGET_ARCH=armv7
@@ -24,7 +24,7 @@ extern "C" {
 #endif  // __cplusplus

 // The enum for builtin operators.
-// Note: CUSTOM and DELEGATE are 2 special ops which are not real biultin
+// Note: CUSTOM and DELEGATE are 2 special ops which are not real builtin
 // ops.
 typedef enum {
   kTfLiteBuiltinAdd = 0,
@@ -30,7 +30,7 @@ namespace tflite {
 //   va_list args;
 //   foo.Report("test %d", args);  // where args is va_list
 //
-// Sublclass ErrorReporter to provide another reporting destination.
+// Subclass ErrorReporter to provide another reporting destination.
 // For example, if you have a GUI program, you might redirect to a buffer
 // that drives a GUI error log box.
 class ErrorReporter {
tensorflow/contrib/lite/g3doc/rpi.md (new file, 50 lines)
@@ -0,0 +1,50 @@
# TensorFlow Lite for Raspberry Pi

## Cross compiling
### Installing the toolchain
This has been tested on Ubuntu 16.04.3 64bit and the TensorFlow devel docker image [tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).

To cross compile TensorFlow Lite, first install the toolchain and libs:
```bash
sudo apt-get update
sudo apt-get install crossbuild-essential-armhf
```
> If you are using docker, you may omit `sudo`.

### Building
Clone this TensorFlow repository, then run this script at the root of the repository to download all the dependencies:
> The TensorFlow repository is in `/tensorflow` if you are using the `tensorflow/tensorflow:nightly-devel` docker image; just try it.
```bash
./tensorflow/contrib/lite/download_dependencies.sh
```
Note that you only need to do this once.

You should then be able to compile:
```bash
./tensorflow/contrib/lite/build_rpi_lib.sh
```

This should compile a static library in:
`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`.

## Native compiling
This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1).

Log in to your Raspberry Pi and install the toolchain:
```bash
sudo apt-get install build-essential
```

First, clone this TensorFlow repository. Run this at the root of the repository:
```bash
./tensorflow/contrib/lite/download_dependencies.sh
```
Note that you only need to do this once.

You should then be able to compile:
```bash
./tensorflow/contrib/lite/build_rpi_lib.sh
```

This should compile a static library in:
`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`.
@@ -481,7 +481,7 @@ class Interpreter {
   // During Invoke(), Interpreter will allocate input tensors first, which are
   // known to be fixed size. Then it will allocate outputs from nodes as many
   // as possible. When there is a node that produces dynamic sized tensor.
-  // Intepreter will stop allocating tensors, set the value of next allocate
+  // Interpreter will stop allocating tensors, set the value of next allocate
   // node id, and execute the node to generate the output tensor before continue
   // to allocate successors. This process repeats until all nodes are executed.
   // NOTE: this relies on the order of nodes that is in topological order.
@@ -40,7 +40,7 @@ TEST(BasicInterpreter, InvokeInvalidModel) {
   ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
 }

-// Test size accesser functions.
+// Test size accessor functions.
 TEST(BasicInterpreter, TestSizeFunctions) {
   Interpreter interpreter;
   int base_index;
@@ -64,7 +64,7 @@ struct OpData {

   TfLitePaddingValues padding;
   // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multipler plus a left shift.
+  // be represented as a fixed point multiplier plus a left shift.
   int32_t output_multiplier;
   int output_shift;
   // The range of the fused activation layer. For example for kNone and
@@ -52,7 +52,7 @@ enum KernelType {
 struct OpData {
   TfLitePaddingValues padding;
   // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multipler plus a left shift.
+  // be represented as a fixed point multiplier plus a left shift.
   int32_t output_multiplier;
   int output_shift;
   // The range of the fused activation layer. For example for kNone and
@@ -48,7 +48,7 @@ enum KernelType {

 struct OpData {
   // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multipler plus a left shift.
+  // be represented as a fixed point multiplier plus a left shift.
   int32_t output_multiplier;
   int output_shift;
   // The range of the fused activation layer. For example for kNone and
@@ -58,7 +58,7 @@ inline bool IsConstantTensor(TfLiteTensor* tensor) {
 }

 // Determines whether tensor is dynamic. Note that a tensor can be non-const and
-// not dynamic. This function specificially checks for a dynamic tensor.
+// not dynamic. This function specifically checks for a dynamic tensor.
 inline bool IsDynamicTensor(TfLiteTensor* tensor) {
   return tensor->allocation_type == kTfLiteDynamic;
 }
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-// LSH Projection projects an input to a bit vector via locality senstive
+// LSH Projection projects an input to a bit vector via locality sensitive
 // hashing.
 //
 // Options:
@@ -213,9 +213,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
   // present.
   // 2) If projection weight is present, then projection bias is optional.
   // TODO(ghodrat): make sure this is correct.
-  const bool projecton_tensors_consistent =
+  const bool projection_tensors_consistent =
       ((projection_weights != nullptr) || (projection_bias == nullptr));
-  TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
+  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

   return kTfLiteOk;
 }
@@ -357,7 +357,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const int n_output = recurrent_to_output_weights->dims->data[1];

   // Since we have already checked that weights are all there or none, we can
-  // check the existense of only one to the get the condition.
+  // check the existence of only one to get the condition.
   const bool use_cifg = (input_to_input_weights == nullptr);
   const bool use_peephole = (cell_to_output_weights != nullptr);
@@ -49,20 +49,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions);
   int num_output_elements = 1;
-  int strech_dim = -1;
+  int stretch_dim = -1;
   for (int i = 0; i < params->num_dimensions; ++i) {
     int value = params->shape[i];
     if (value == -1) {
-      TF_LITE_ENSURE_EQ(context, strech_dim, -1);
-      strech_dim = i;
+      TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
+      stretch_dim = i;
     } else {
       num_output_elements *= value;
       output_size->data[i] = value;
     }
   }
-  if (strech_dim != -1) {
-    output_size->data[strech_dim] = num_input_elements / num_output_elements;
-    num_output_elements *= output_size->data[strech_dim];
+  if (stretch_dim != -1) {
+    output_size->data[stretch_dim] = num_input_elements / num_output_elements;
+    num_output_elements *= output_size->data[stretch_dim];
   }
 
   TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
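The stretch-dimension rule in this kernel is the same one TensorFlow's `reshape` exposes in Python; a minimal sketch (TF 1.x style, illustrative shapes, not part of this diff):

```python
import tensorflow as tf

# Exactly one dimension may be -1; it is "stretched" so the total element
# count is preserved, mirroring the stretch_dim inference in the kernel above.
x = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])  # 8 elements
y = tf.reshape(x, [-1, 2])                     # -1 is inferred as 4
with tf.Session() as sess:
    print(sess.run(y).shape)  # (4, 2)
```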
@@ -60,7 +60,7 @@ TEST(ReshapeOpTest, TooManyDimensions) {
 
 TEST(ReshapeOpTest, TooManySpecialDimensions) {
   EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}),
-               "strech_dim != -1");
+               "stretch_dim != -1");
 }
 
 TEST(ReshapeOpTest, SimpleTest) {
@@ -141,8 +141,8 @@ void SingleOpModel::SetBuiltinOp(BuiltinOperator type,
 
 void SingleOpModel::SetCustomOp(
     const string& name, const std::vector<uint8_t>& custom_option,
-    const std::function<TfLiteRegistration*()>& registeration) {
-  custom_registrations_[name] = registeration;
+    const std::function<TfLiteRegistration*()>& registration) {
+  custom_registrations_[name] = registration;
   opcodes_.push_back(
       CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data()));
   operators_.push_back(CreateOperator(
@@ -360,7 +360,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const int n_output = recurrent_to_output_weights->dims->data[1];
 
   // Since we have already checked that weights are all there or none, we can
-  // check the existense of only one to the get the condition.
+  // check the existence of only one to get the condition.
   const bool use_cifg = (input_to_input_weights == nullptr);
   const bool use_peephole = (cell_to_output_weights != nullptr);
 
@@ -34,8 +34,8 @@ class MemoryPlanner {
   // [first_node, last_node].
   virtual TfLiteStatus ExecuteAllocations(int first_node, int last_node) = 0;
 
-  // Invalidates allocations made earliers. This is called when tensors sizes
-  // have change. All planned allocations remain, but can't be used until
+  // Invalidates allocations made earlier. This is called when tensors sizes
+  // have changed. All planned allocations remain, but can't be used until
   // ExecuteAllocations() is called.
   virtual TfLiteStatus ResetAllocations() = 0;
 };
@@ -64,7 +64,7 @@ class FlatBufferModel {
       const tflite::Model* model_spec,
       ErrorReporter* error_reporter = DefaultErrorReporter());
 
-  // Releases memory or unmaps mmaped meory.
+  // Releases memory or unmaps mmaped memory.
   ~FlatBufferModel();
 
   // Copying or assignment is disallowed to simplify ownership semantics.
@@ -569,7 +569,7 @@ enum {
   ANEURALNETWORKS_LOGISTIC = 14,
 
   /**
-   * Projects an input to a bit vector via locality senstive hashing.
+   * Projects an input to a bit vector via locality sensitive hashing.
    *
    * Inputs:
    * * 0: Hash functions. Dim.size == 2, DataType: Float.
tensorflow/contrib/lite/rpi_makefile.inc (new file, 33 lines)
@@ -0,0 +1,33 @@
+# Settings for Raspberry Pi.
+ifeq ($(TARGET), RPI)
+  ifeq ($(TARGET_ARCH), armv7)
+    CXXFLAGS += \
+      -march=armv7-a \
+      -mfpu=neon-vfpv4 \
+      -funsafe-math-optimizations \
+      -ftree-vectorize
+
+    CCFLAGS += \
+      -march=armv7-a \
+      -mfpu=neon-vfpv4 \
+      -funsafe-math-optimizations \
+      -ftree-vectorize
+
+    LDFLAGS := \
+      -Wl,--no-export-dynamic \
+      -Wl,--exclude-libs,ALL \
+      -Wl,--gc-sections \
+      -Wl,--as-needed
+  endif
+
+  LIBS := \
+    -lstdc++ \
+    -lpthread \
+    -lm \
+    -ldl
+
+  OBJDIR := $(OBJDIR)rpi_$(TARGET_ARCH)/
+  LIBDIR := $(LIBDIR)rpi_$(TARGET_ARCH)/
+  BINDIR := $(BINDIR)rpi_$(TARGET_ARCH)/
+  DEPDIR := $(DEPDIR)rpi_$(TARGET_ARCH)/
+endif
@ -46,7 +46,7 @@ extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// The enum for builtin operators.
|
||||
// Note: CUSTOM and DELEGATE are 2 special ops which are not real biultin
|
||||
// Note: CUSTOM and DELEGATE are 2 special ops which are not real builtin
|
||||
// ops.
|
||||
typedef enum {
|
||||
)";
|
||||
|
@@ -113,21 +113,21 @@ TfLiteStatus SimpleMemoryArena::Commit(TfLiteContext* context) {
     underlying_buffer_size_ = required_size;
     underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr;
   }
-  commited_ = true;
+  committed_ = true;
   return underlying_buffer_ != nullptr ? kTfLiteOk : kTfLiteError;
 }
 
 TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context,
                                              const ArenaAlloc& alloc,
                                              char** output_ptr) {
-  TF_LITE_ENSURE(context, commited_);
+  TF_LITE_ENSURE(context, committed_);
   TF_LITE_ENSURE(context, output_ptr != nullptr);
   *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset;
   return kTfLiteOk;
 }
 
 TfLiteStatus SimpleMemoryArena::Clear() {
-  commited_ = false;
+  committed_ = false;
   high_water_mark_ = 0;
   allocs_.clear();
   return kTfLiteOk;
@@ -22,7 +22,7 @@ limitations under the License.
 namespace tflite {
 
 // This little structure holds the offset and the size for a dynamic memory
-// allocation in the memory arena. When the arena is commited and the
+// allocation in the memory arena. When the arena is committed and the
 // underlying buffer is set, the alloc can be resolved into an actual memory
 // pointer.
 struct ArenaAlloc {
@@ -43,7 +43,7 @@ struct ArenaAlloc {
 class SimpleMemoryArena {
  public:
   explicit SimpleMemoryArena(size_t arena_alignment)
-      : commited_(false),
+      : committed_(false),
         arena_alignment_(arena_alignment),
         high_water_mark_(0),
         underlying_buffer_size_(0),
@@ -73,7 +73,7 @@ class SimpleMemoryArena {
   }
 
  private:
-  bool commited_;
+  bool committed_;
   size_t arena_alignment_;
   size_t high_water_mark_;
   std::unique_ptr<char[]> underlying_buffer_;
@@ -46,6 +46,7 @@ tf_py_test(
         "//tensorflow/python:variables",
     ],
     grpc_enabled = True,
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
 )
 
 filegroup(
@@ -194,6 +194,8 @@ with:
 srcs = glob(["libs/arm64-v8a/*.so"]),
 ```
 
+If you are building for Android TV (Shield TV devices), replace "portrait" with "landscape" for android:screenOrientation in all four activities in tensorflow/examples/android/AndroidManifest.xml
+
 Then run:
 ```bash
 # Create dir for native libs
@@ -80,10 +80,9 @@ if [[ ! -z "${OPTIMIZE_FOR_GRAPH}" ]]; then
   fi
 else
   echo "${PRNT_SLCTV_BIN} found. Using it"
-  ${PRNT_SLCTV_BIN} --graphs=${OPTIMIZE_FOR_GRAPH} > ${TOP_SRCDIR}/tensorflow/core/framework/ops_to_register.h
-
 fi
+
+${PRNT_SLCTV_BIN} --graphs=${OPTIMIZE_FOR_GRAPH} > ${TOP_SRCDIR}/tensorflow/core/framework/ops_to_register.h
 fi
 
 if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then
@@ -24,6 +24,8 @@ limitations under the License.
 
 #include "tensorflow/core/lib/strings/str_util.h"
 
+// Skip MPI C++ bindings support, this matches the usage in other places
+#define OMPI_SKIP_MPICXX
 #include "third_party/mpi/mpi.h"
 #define MPI_CHECK(cmd) \
   do { \
@@ -53,7 +53,7 @@ def from_contrib_estimator(estimator,
       `Estimator`.
   """
   if isinstance(estimator, core_estimator.Estimator):
-    raise TypeError('Espected estimator to be of type '
+    raise TypeError('Expected estimator to be of type '
                     'tf.contrib.learn.Estimator, but got type '
                     'tf.python.estimator.Estimator. You likely want to call '
                     'from_estimator.')
@@ -88,7 +88,7 @@ def from_estimator(estimator,
      `Estimator`.
   """
   if isinstance(estimator, contrib_estimator.Estimator):
-    raise TypeError('Espected estimator to be of type '
+    raise TypeError('Expected estimator to be of type '
                    'tf.python.estimator.Estimator, but got type '
                    'tf.contrib.learn.Estimator. You likely want to call '
                    'from_contrib_estimator.')
@@ -81,6 +81,7 @@ py_test(
     name = "builtin_functions_test",
     srcs = ["builtin_functions_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":test_lib",
         "//tensorflow/python:client_testlib",
@@ -91,6 +92,7 @@ py_test(
     name = "call_trees_test",
     srcs = ["call_trees_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":test_lib",
         "//tensorflow/contrib/py2tf/impl",
@@ -212,7 +212,7 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
 
   def __init__(self):
     self.cant_return = False
-    super(gast.NodeVisitor, self).__init__()
+    super(DetectReturnInUnsupportedControlFlow, self).__init__()
 
   def visit_While(self, node):
     self.cant_return = True
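The `super` fix above is the standard Python pattern: the first argument names the class being defined, not a base class. A minimal illustration (generic classes, not from this diff):

```python
class Base(object):
    def __init__(self):
        self.ready = True

class Child(Base):
    def __init__(self):
        # Correct: start the MRO lookup after Child, so Base.__init__ runs.
        super(Child, self).__init__()
        # super(Base, self).__init__() would skip Base.__init__ entirely,
        # which is the bug the hunk above fixes for gast.NodeVisitor.

print(Child().ready)  # True
```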
@@ -83,6 +83,7 @@ py_test(
     name = "py_func_test",
     srcs = ["py_func_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":utils",
         "//tensorflow/python:client_testlib",
@@ -237,7 +237,7 @@ def _FindFusedBatchNorms(graph):
       # The batch variance used during forward and backward prop is biased,
       # i.e it is calculated as: V=sum(x(k)-mu)^2/N. For the moving average
       # calculation, the variance is corrected by the term N/N-1 (Bessel's
-      # correction). The variance tensor read from FuseBatchNorm has bessel's
+      # correction). The variance tensor read from FuseBatchNorm has Bessel's
       # correction applied, so we undo it here.
       scope, sep, _ = bn_op.name.rpartition('/')
       g = ops.get_default_graph()
@@ -306,7 +306,7 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
 
   Args:
     context: The scope under which we look for batch norm params
-    match: Object containg required batch norm tensors for correction
+    match: Object containing required batch norm tensors for correction
       computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.
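The Bessel's-correction note above is easy to verify numerically; a small standalone sketch (plain NumPy, not part of this diff):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
n = x.size
biased = np.mean((x - x.mean()) ** 2)        # V = sum((x - mu)^2) / N
corrected = biased * n / (n - 1)             # Bessel's correction: N/(N-1)
assert np.isclose(corrected, x.var(ddof=1))  # matches the unbiased estimator
# FusedBatchNorm reports the corrected variance, so the pass above divides
# the correction back out to recover the biased batch variance.
```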
@@ -282,8 +282,8 @@ def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits,
   Args:
     inputs: a tensor containing values to be quantized.
     min_var: a variable containing quantization range lower end(s).
-    max_var: a variable containing quantization range lupper end(s).
-    per_channel: a boolean specifying whether to use per-channel quantizatioh.
+    max_var: a variable containing quantization range upper end(s).
+    per_channel: a boolean specifying whether to use per-channel quantization.
     num_bits: Number of bits to use for quantization, must be between 2 and 8.
     narrow_range: Whether to use the narrow quantization range
       [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
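The corrected docstring maps onto the public fake-quant op; a small sketch with scalar range variables (illustrative values, TF 1.x, using the public `tf.fake_quant_with_min_max_vars` rather than this private helper):

```python
import tensorflow as tf

inputs = tf.constant([-1.5, 0.0, 0.4, 2.0])
min_var = tf.Variable(-1.0)  # quantization range lower end
max_var = tf.Variable(1.0)   # quantization range upper end
quantized = tf.fake_quant_with_min_max_vars(
    inputs, min_var, max_var, num_bits=8, narrow_range=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Values are clamped into [min, max] and snapped to the 8-bit grid.
    print(sess.run(quantized))
```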
@@ -267,7 +267,7 @@ def _InsertQuantOp(context,
   """Inserts a quant op between a producer op and (multiple) consumer ops.
 
   Args:
-    context: Context w,here producer and consumer operations are nested.
+    context: Context where producer and consumer operations are nested.
     name: Name for the new quantization op within the context.
     producer: Producer operation of the pairs where quantization will be
       inserted.
@@ -158,7 +158,7 @@ def experimental_create_training_graph(input_graph=None,
     often fail.
 
   Args:
-    input_graph: The tf.Graph to be transformed,if None then defaults to the
+    input_graph: The tf.Graph to be transformed, if None then defaults to the
       default graph.
     weight_bits: Number of bits to use for quantizing weights.
     activation_bits: Number of bits to use for quantizing activations.
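A hedged usage sketch for the function whose docstring is fixed above (TF 1.x contrib API; the tiny graph is illustrative and not from this diff, and whether the rewrite finds quantizable patterns depends on the layers used):

```python
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize

g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, [1, 8, 8, 3])
    net = tf.layers.conv2d(x, 4, 3, activation=tf.nn.relu)
    # Rewrites the graph in place, inserting fake-quant ops with the
    # requested bit widths; with input_graph=None it would rewrite the
    # default graph, as the corrected docstring states.
    contrib_quantize.experimental_create_training_graph(
        input_graph=g, weight_bits=8, activation_bits=8)
```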
@@ -419,7 +419,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         normalizer_params=self._BatchNormParams(fused_batch_norm),
         scope=scope)
 
-    # Manually add a bypass (optionaly) and an activation.
+    # Manually add a bypass (optional) and an activation.
     if with_bypass:
       node = math_ops.add(inputs, node, name='test/Add')
 
@@ -470,7 +470,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         normalizer_params=self._BatchNormParams(fused_batch_norm),
         scope=scope)
 
-    # Manually add a bypass (optionaly) and an activation.
+    # Manually add a bypass (optional) and an activation.
     if with_bypass:
       node = math_ops.add(inputs, node, name='test/Add')
 
@@ -526,7 +526,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         normalizer_params=self._BatchNormParams(fused_batch_norm),
         scope=scope)
 
-    # Manually add a bypass (optionaly) and an activation.
+    # Manually add a bypass (optional) and an activation.
     if with_bypass:
       node = math_ops.add(inputs, node, name='test/Add')
 
@@ -565,7 +565,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       stddev: Standard deviation of normal variable.
 
     Returns:
-      An initialized that initialzes with a truncated normal variable.
+      An initialized that initializes with a truncated normal variable.
     """
     return init_ops.truncated_normal_initializer(stddev=stddev)
 
@@ -144,7 +144,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      stddev: Standard deviation of normal variable.
 
    Returns:
-      An initialized that initialzes with a truncated normal variable.
+      An initialized that initializes with a truncated normal variable.
    """
    return init_ops.truncated_normal_initializer(stddev=stddev)
 
@@ -38,7 +38,6 @@ py_test(
     size = "small",
     srcs = ["python/ops/remote_fused_graph_ops_test.py"],
     srcs_version = "PY2AND3",
-    tags = ["no_windows"],
     deps = [
         ":remote_fused_graph_ops_py",
         "//tensorflow/core:protos_all_py",
@@ -2133,7 +2133,7 @@ class Conv1DLSTMCell(ConvLSTMCell):
 
   def __init__(self, name="conv_1d_lstm_cell", **kwargs):
     """Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
-    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs)
+    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs)
 
 
 class Conv2DLSTMCell(ConvLSTMCell):
@@ -2144,7 +2144,7 @@ class Conv2DLSTMCell(ConvLSTMCell):
 
   def __init__(self, name="conv_2d_lstm_cell", **kwargs):
     """Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
-    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs)
+    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs)
 
 
 class Conv3DLSTMCell(ConvLSTMCell):
@@ -2155,7 +2155,7 @@ class Conv3DLSTMCell(ConvLSTMCell):
 
   def __init__(self, name="conv_3d_lstm_cell", **kwargs):
     """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
-    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
+    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs)
 
 
 def _conv(args, filter_size, num_features, bias, bias_start=0.0):
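The three fixes above all forward the constructor's `name` argument to the base class; a hedged sketch of the visible effect (TF 1.x contrib API, illustrative shapes):

```python
from tensorflow.contrib.rnn import Conv2DLSTMCell

# Before this fix, name="..." was silently dropped and every cell fell back
# to the base-class default scope; now the argument reaches ConvLSTMCell.
cell = Conv2DLSTMCell(input_shape=[8, 8, 3], output_channels=4,
                      kernel_shape=[3, 3], name="my_conv_lstm")
print(cell.name)  # my_conv_lstm
```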
@@ -53,6 +53,7 @@ py_test(
     size = "small",
     srcs = ["python/saved_model/reader_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     visibility = ["//visibility:private"],
     deps = [
         ":saved_model_py",
@@ -299,12 +299,13 @@ class BeamSearchDecoder(decoder.Decoder):
     """
     finished, start_inputs = self._finished, self._start_inputs
 
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
     log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
         array_ops.zeros([self._batch_size], dtype=dtypes.int32),
         depth=self._beam_width,
-        on_value=0.0,
-        off_value=-np.Inf,
-        dtype=nest.flatten(self._initial_cell_state)[0].dtype)
+        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
+        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
+        dtype=dtype)
 
     initial_state = BeamSearchDecoderState(
         cell_state=self._initial_cell_state,
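The hunk above stops passing raw Python floats to `one_hot`, so non-float32 cell states (e.g. float16) keep their dtype. A standalone sketch of the pattern (illustrative sizes, not part of this diff):

```python
import numpy as np
import tensorflow as tf

dtype = tf.float16  # stands in for the cell state's dtype
log_probs = tf.one_hot(
    tf.zeros([2], dtype=tf.int32),  # first beam of each batch entry
    depth=3,
    on_value=tf.convert_to_tensor(0.0, dtype=dtype),
    off_value=tf.convert_to_tensor(-np.inf, dtype=dtype),
    dtype=dtype)
with tf.Session() as sess:
    print(sess.run(log_probs))  # each row is [0, -inf, -inf] in float16
```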
@@ -165,6 +165,7 @@ py_test(
     name = "gc_test",
     srcs = ["gc_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     visibility = ["//visibility:private"],
     deps = [
         ":gc",
@@ -61,6 +61,7 @@ py_test(
     name = "dataset_data_provider_test",
     srcs = ["dataset_data_provider_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_windows"],  # TODO: needs investigation on Windows
     deps = [
         ":dataset",
         ":dataset_data_provider",
@@ -553,7 +553,6 @@ py_test(
     srcs = ["client/random_forest_test.py"],
     srcs_version = "PY2AND3",
     tags = [
-        "no_windows",
        "nomac",  # b/63258195
        "notsan",
    ],
@ -9,6 +9,7 @@ exports_files(["LICENSE"])
|
||||
|
||||
# For platform specific build config
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
tf_proto_library(
|
||||
name = "protos_all",
|
||||
|
@@ -83,6 +83,7 @@ cc_library(
         "kernels/trt_engine_op.h",
     ],
+    copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = [
         ":trt_logging",
         ":trt_resources",
@@ -154,6 +155,7 @@ py_library(
     deps = [
         ":trt_convert_py",
+        ":trt_ops_py",
         "//tensorflow/python:errors",
     ],
 )
 
@@ -2,7 +2,8 @@ Using TensorRT in TensorFlow
 ============================
 
 This module provides necessary bindings and introduces TRT_engine_op
-operator that wraps a subgraph in TensorRT.
+operator that wraps a subgraph in TensorRT. This is still a work in progress
+but should be useable with most common graphs.
 
 Compilation
 -----------
@@ -15,26 +16,10 @@ configure script should find the necessary components from the system
 automatically. If installed from tar packages, user has to set path to
 location where the library is installed during configuration.
 
-```
+```shell
 bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package
 bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
 ```
 
 After the installation of tensorflow package, TensorRT transformation
-will be available. An example use is shown below.
-
-```python
-import tensorflow as tf
-import tensorflow.contrib.tensorrt as trt
-#... create and train or load model
-gdef = sess.graph.as_graph_def()
-trt_gdef = trt.create_inference_graph(
-    gdef, #original graph_def
-    ["output"], #name of output node(s)
-    max_batch_size, #maximum batch size to run the inference
-    max_workspace_size_bytes) # max memory for TensorRT to use
-tf.reset_default_graph()
-tf.import_graph_def(graph_def=trt_gdef)
-#...... run inference
-```
+will be available. An example use can be found in test/test_tftrt.py directory
@@ -18,6 +18,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,wildcard-import
-from tensorflow.contrib.tensorrt.python import *
-# pylint: enable=unused-import,wildcard-import
+from tensorflow.python.framework import errors
+
+# pylint: disable=unused-import,wildcard-import,g-import-not-at-top
+try:
+  from tensorflow.contrib.tensorrt.python import *
+except errors.NotFoundError as e:
+  no_trt_message = (
+      '**** Failed to initialize TensorRT. This is either because the TensorRT'
+      ' installation path is not in LD_LIBRARY_PATH, or because you do not have'
+      ' it installed. If not installed, please go to'
+      ' https://developer.nvidia.com/tensorrt to download and install'
+      ' TensorRT ****')
+  print(no_trt_message)
+  raise e
+# pylint: enable=unused-import,wildcard-import,g-import-not-at-top
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
 
+#include <list>
 #include <map>
 #include <set>
 #include <unordered_map>
@@ -48,13 +49,29 @@ namespace tensorrt {
 namespace convert {
 namespace {
 
-static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
+bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
+  // LINT.IfChange
   // TODO(jie): Segmentation shouldn't associated with op name.
   //            Split it into a registration for each kernel.
   static const std::set<string> candidate_ops = {
-      "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
-      "Add",      "Mul",   "Sub",    "Rsqrt",   "Pad"  // "Placeholder" ,"Mean"
+      "Identity",
+      "Const",
+      "Conv2D",
+      "MaxPool",
+      "BiasAdd",
+      "Relu",
+      "Add",
+      "Mul",
+      "Sub",
+      "Rsqrt",
+      "Pad",
+      "Mean",
+      "AvgPool",
+      "ConcatV2",
+      "DepthwiseConv2dNative",
+      "FusedBatchNorm",
+      "FusedBatchNormV2",
+      // TODO(ben,jie): ...
   };
+  // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
   return candidate_ops.count(node_def.op());
@@ -69,6 +86,8 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
       if (!subgraph_node_ids.count(edge->src()->id()) &&
           !edge->src()->IsSource()) {
         incoming_edges->insert(edge);
+      } else {
+        VLOG(2) << edge->src()->name() << " N, ";
       }
     }
   }
@@ -82,7 +101,10 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
     for (const tensorflow::Edge* edge : node->out_edges()) {
       if (!subgraph_node_ids.count(edge->dst()->id()) &&
           !edge->dst()->IsSink()) {
+        VLOG(2) << edge->dst()->name() << " Y, ";
         outgoing_edges->insert(edge);
+      } else {
+        VLOG(2) << edge->dst()->name() << " N, ";
       }
     }
   }
@@ -109,74 +131,150 @@ std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
   }
   return result;
 }
 
-tensorflow::Status ConvertSubGraphToTensorRT(
-    const std::vector<string>& output_names,
-    const std::set<int>& subgraph_node_ids,
-    size_t max_batch_size,  // Max batch size that engine will be created for
-    // Max amount of memory that engine will be allowed to consume, in bytes
-    size_t max_workspace_size_bytes,
-    const tensorflow::grappler::GraphProperties& graph_properties,
-    tensorflow::Graph* graph) {
-  tensorflow::EdgeSet subgraph_incoming_edges;
-  GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
-
-  // Collect inputs by looking for incoming edges
-  for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
-    subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
+// TODO(sami): convert references to pointers
+struct ConvertGraphParams {
+  ConvertGraphParams(
+      tensorflow::Graph& inp_graph,
+      const std::vector<string>& output_node_names,
+      const std::set<int>& subgraph_node_id_numbers,
+      size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
+      const tensorflow::grappler::GraphProperties& current_graph_properties,
+      std::unordered_map<string, std::pair<int, string>>* output_edges,
+      int engine_precision_mode)
+      : graph(inp_graph),
+        output_names(output_node_names),
+        subgraph_node_ids(subgraph_node_id_numbers),
+        max_batch_size(max_supported_batch_size),
+        max_workspace_size_bytes(max_consumed_workspace_size_bytes),
+        graph_properties(current_graph_properties),
+        output_edge_map(output_edges),
+        precision_mode(engine_precision_mode) {}
+  tensorflow::Graph& graph;
+  const std::vector<string>& output_names;
+  const std::set<int>& subgraph_node_ids;
+  size_t max_batch_size;
+  size_t max_workspace_size_bytes;
+  const tensorflow::grappler::GraphProperties& graph_properties;
+  std::unordered_map<string, std::pair<int, string>>* output_edge_map;
+  int precision_mode;
+  std::vector<std::pair<int, int>> subgraph_inputs;
+  std::vector<std::pair<int, int>> subgraph_outputs;
+  tensorflow::EdgeSet subgraph_incoming_edges;
+  tensorflow::EdgeSet subgraph_outgoing_edges;
+};
+
+static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) {
+  GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids,
+                           &p->subgraph_incoming_edges);
+  for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) {
+    p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
   }
+  auto output_name_to_index_map = BuildTensorNameMap(p->output_names);
   std::set<std::pair<int, int>> subgraph_outputs_set;
   // Collect outputs referenced from output_names
-  auto output_name_to_index_map = BuildTensorNameMap(output_names);
-  for (int node_id : subgraph_node_ids) {
-    tensorflow::Node* node = graph->FindNodeId(node_id);
+  for (int node_id : p->subgraph_node_ids) {
+    tensorflow::Node* node = p->graph.FindNodeId(node_id);
     if (output_name_to_index_map.count(node->name())) {
       for (int index : output_name_to_index_map.at(node->name())) {
         subgraph_outputs_set.insert({node_id, index});
       }
     }
   }
   // Collect outputs referenced from outgoing edges
-  tensorflow::EdgeSet subgraph_outgoing_edges;
-  GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
-  for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
+  GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids,
+                           &p->subgraph_outgoing_edges);
+  for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) {
     subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
   }
   // Impose an ordering on the outputs
-  std::vector<std::pair<int, int>> subgraph_outputs(
-      subgraph_outputs_set.begin(), subgraph_outputs_set.end());
-  // Build TensorRT node and add it to the graph
+  p->subgraph_outputs.reserve(subgraph_outputs_set.size());
+  p->subgraph_outputs.insert(p->subgraph_outputs.begin(),
+                             subgraph_outputs_set.begin(),
+                             subgraph_outputs_set.end());
+  return tensorflow::Status::OK();
+};
+
+tensorflow::Status GetCalibNode(ConvertGraphParams* params) {
+  TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params));
   tensorflow::NodeDef trt_node_def;
-  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
-      *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
-      max_batch_size, max_workspace_size_bytes, graph_properties,
-      &trt_node_def));
+  SubGraphParams s(params->graph, params->subgraph_node_ids,
+                   params->subgraph_inputs, params->subgraph_outputs,
+                   params->max_batch_size, params->max_workspace_size_bytes,
+                   params->graph_properties, params->output_edge_map,
+                   &trt_node_def, params->precision_mode);
+  TF_RETURN_IF_ERROR(InjectCalibrationNode(s));
   tensorflow::Status status;
-  tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
+  tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
 
   TF_RETURN_IF_ERROR(status);
 
+  for (auto in_edge :
+       params->subgraph_incoming_edges) {  // loop over incoming edges and
+                                           // attach them to calib node
+    // tensorflow::Node* src_node = in_edge->src();
+    auto src_output = in_edge->src_output();
+    auto dst_node = in_edge->dst();
+    auto dst_input = in_edge->dst_input();
+    VLOG(1) << " update edge " << trt_node->name() << ":" << src_output
+            << " -> " << dst_node->name() << ":" << dst_input;
+    TF_RETURN_IF_ERROR(
+        params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input));
+  }
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
+  TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params));
+  tensorflow::NodeDef trt_node_def;
+
+  SubGraphParams s(params->graph, params->subgraph_node_ids,
+                   params->subgraph_inputs, params->subgraph_outputs,
+                   params->max_batch_size, params->max_workspace_size_bytes,
+                   params->graph_properties, params->output_edge_map,
+                   &trt_node_def, params->precision_mode);
+  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s));
+  tensorflow::Status status;
+  tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
+
   // AddNode does not wire edges.
   // Re-map incoming edges to use the new TRT node instead of the orig subgraph
   std::map<std::pair<int, int>, int> subgraph_edge_to_input_map;
   for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) {
     subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i});
   }
   for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) {
     std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
     int new_src_output = subgraph_edge_to_input_map.at(old_src);
     params->graph.AddEdge(edge->src(), edge->src_output(), trt_node,
                           new_src_output);
     params->graph.RemoveEdge(edge);
   }
 
   VLOG(2) << "new wiring edges: " << trt_node->in_edges().size();
   for (const tensorflow::Edge* edge : trt_node->in_edges()) {
     VLOG(2) << edge->src()->name() << " port: " << edge->src_output();
   }
 
   TF_RETURN_IF_ERROR(status);
 
   // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
   std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
-  for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
-    subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
+  for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) {
+    subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i});
   }
   TF_RETURN_IF_ERROR(status);
-  for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
+  for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) {
     std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
     int new_src_output = subgraph_edge_to_output_map.at(old_src);
-    TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(),
-                                         edge->dst_input()));
+    TF_RETURN_IF_ERROR(params->graph.UpdateEdge(
+        trt_node, new_src_output, edge->dst(), edge->dst_input()));
   }
   // Remove the original subgraph
-  for (int node_id : subgraph_node_ids) {
-    tensorflow::Node* node = graph->FindNodeId(node_id);
+  for (int node_id : params->subgraph_node_ids) {
+    tensorflow::Node* node = params->graph.FindNodeId(node_id);
     // Don't remove the input placeholders
     if (node->type_string() == "Placeholder") {
       continue;
     }
-    graph->RemoveNode(node);
+    params->graph.RemoveNode(node);
   }
   return tensorflow::Status::OK();
 }
@@ -194,12 +292,39 @@ tensorflow::Status BuildNodeMap(
 }
 
 }  // namespace
+
+tensorflow::Status ConvertCalibGraphToInferGraph(
+    const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) {
+  VLOG(0) << "Starting Calib Conversion";
+  tensorflow::Graph graph(tensorflow::OpRegistry::Global());
+  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+      tensorflow::GraphConstructorOptions(), graph_def, &graph));
+  // get calib nodes
+  std::vector<tensorflow::Node*> calib_nodes;
+  for (auto node : graph.op_nodes()) {
+    if (node->type_string() == "TRTCalibOp") {
+      VLOG(1) << "Found Calib Node";
+      calib_nodes.push_back(node);
+    }
+  }
+  VLOG(0) << "Num Calib nodes in graph= " << calib_nodes.size();
+  if (calib_nodes.size() == 0)
+    return tensorflow::errors::FailedPrecondition(
+        "Graph doesn't contain any calibration nodes!."
+        " Please generate calibration graph and run calibration first");
+  for (auto n : calib_nodes) {
+    TF_RETURN_IF_ERROR(
+        tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n));
+  }
+  graph.ToGraphDef(infer_graph);
+  return tensorflow::Status::OK();
+}
 
 tensorflow::Status ConvertGraphDefToTensorRT(
     const tensorflow::GraphDef& graph_def,
     const std::vector<string>& output_names, size_t max_batch_size,
-    size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) {
-  // Optimization pass
+    size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
+    int precision_mode = FP32MODE, int minimum_segment_size = 3) {
+  // optimization pass
   tensorflow::grappler::GrapplerItem item;
   item.fetch = output_names;
   tensorflow::GraphDef gdef;
@@ -209,16 +334,23 @@ tensorflow::Status ConvertGraphDefToTensorRT(
   tensorflow::grappler::LayoutOptimizer optimizer;
   tensorflow::grappler::Cluster* cluster;
 
-  // Virtual cluster
+  // virtual cluster
   tensorflow::DeviceProperties device_properties;
+
   device_properties.set_type("GPU");
   device_properties.mutable_environment()->insert({"architecture", "6"});
   cluster =
       new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
 
+  // single machine
+  int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
+  int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
+  VLOG(2) << "cpu_cores: " << num_cpu_cores;
+  VLOG(2) << "gpus: " << num_gpus;
+
   TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef));
 
-  // Constant folding
+  // constant folding
   item.graph = gdef;
   tensorflow::grappler::ConstantFolding fold(nullptr);
   TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));
@@ -226,7 +358,6 @@ tensorflow::Status ConvertGraphDefToTensorRT(
   // AJ refactoring shape inference through grappler/GraphProperties.
   tensorflow::grappler::GraphProperties static_graph_properties(item);
   TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false));
-
   // Build full graph
   tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
                                              gdef.library());
@@ -243,7 +374,7 @@ tensorflow::Status ConvertGraphDefToTensorRT(
   }
 
   // TODO(sami): this should be passed as a knob!!!!
-  segment_options.minimum_segment_size = 2;
+  segment_options.minimum_segment_size = minimum_segment_size;
   tensorflow::tensorrt::segment::SegmentNodesVector segments;
   TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
       gdef, IsTensorRTCandidate, segment_options, &segments));
@@ -252,14 +383,37 @@ tensorflow::Status ConvertGraphDefToTensorRT(
   }
   std::unordered_map<string, tensorflow::Node*> node_map;
   TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
+  std::unordered_map<string, std::pair<int, string>> output_edge_map;
+  int count = 0;
+  float total_num_nodes_in_segments = 0.;
+  for (auto s : segments) {
+    total_num_nodes_in_segments += s.size();
+  }
   for (const std::set<string>& subgraph_node_names : segments) {
     std::set<int> subgraph_node_ids;
+    size_t max_mem_per_engine =
+        max_workspace_size_bytes *
+        ((float)subgraph_node_names.size() / total_num_nodes_in_segments);
+    std::stringstream oss;
     for (const string& node_name : subgraph_node_names) {
+      oss << " " << node_name;
       subgraph_node_ids.insert(node_map.at(node_name)->id());
     }
-    TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
-        output_names, subgraph_node_ids, max_batch_size,
-        max_workspace_size_bytes, static_graph_properties, &graph));
+    VLOG(2) << "Subgraph nodes" << oss.str();
+    ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size,
+                         max_mem_per_engine, static_graph_properties,
+                         &output_edge_map, precision_mode);
+    if (precision_mode == INT8MODE) {
+      TF_RETURN_IF_ERROR(GetCalibNode(&p));
+    } else {
+      tensorflow::Status status = ConvertSubGraphToTensorRT(&p);
+      if (status != tensorflow::Status::OK()) {
+        LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
+                     << " due to: \n"
+                     << status.ToString() << " SKIPPING......";
+      }
+      count++;
+    }
   }
   graph.ToGraphDef(new_graph_def);
   return tensorflow::Status::OK();
@@ -28,6 +28,11 @@ namespace tensorflow {
 namespace tensorrt {
 namespace convert {
 
+// This method converts an already generated calibration graph which was used in
+// calibration runs to an inference graph
+tensorflow::Status ConvertCalibGraphToInferGraph(
+    const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def);
+
 // max_batch_size: maximum batch size which can be used for inference for
 //   optimization targets inference run with max batch size.
 // max_workspace_size_bytes: The upper bound of memory allowence for
@@ -35,7 +40,8 @@ namespace convert {
 tensorflow::Status ConvertGraphDefToTensorRT(
     const tensorflow::GraphDef& graph_def,
     const std::vector<string>& output_names, size_t max_batch_size,
-    size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def);
+    size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
+    int precision_mode, int minimum_segment_size);
 
 }  // namespace convert
 }  // namespace tensorrt
(File diff suppressed because it is too large.)
@@ -17,6 +17,8 @@ limitations under the License.
 #define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
 
+#include <set>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
@@ -32,16 +34,49 @@ namespace tensorflow {
 namespace tensorrt {
 namespace convert {
 
-tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
-    const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
-    const std::vector<std::pair<int, int>>&
-        input_inds,  // {node_id, output_idx}
-    const std::vector<std::pair<int, int>>&
-        output_inds,  // {node_id, output_idx}
-    size_t max_batch_size, size_t max_workspace_size_bytes,
-    const tensorflow::grappler::GraphProperties& graph_prop,
-    tensorflow::NodeDef* trt_node);
+const int FP32MODE = 0;
+const int FP16MODE = 1;
+const int INT8MODE = 2;
+
+struct SubGraphParams {
+  SubGraphParams(
+      tensorflow::Graph& inp_graph,
+      const std::set<int>& subgraph_node_id_numbers,
+      const std::vector<std::pair<int, int>>& input_indices,
+      const std::vector<std::pair<int, int>>& output_indices,
+      size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
+      const tensorflow::grappler::GraphProperties& current_graph_properties,
+      std::unordered_map<string, std::pair<int, string>>* output_edges,
+      tensorflow::NodeDef* constructed_trt_node,
+      int engine_precision_mode = FP32MODE)
+      : graph(inp_graph),
+        subgraph_node_ids(subgraph_node_id_numbers),
+        input_inds(input_indices),
+        output_inds(output_indices),
+        max_batch_size(max_supported_batch_size),
+        max_workspace_size_bytes(max_consumed_workspace_size_bytes),
+        graph_properties(current_graph_properties),
+        output_edge_map(output_edges),
+        trt_node(constructed_trt_node),
+        precision_mode(engine_precision_mode) {}
+
+  tensorflow::Graph& graph;
+  const std::set<int>& subgraph_node_ids;
+  const std::vector<std::pair<int, int>>& input_inds;   // {node_id, output_idx}
+  const std::vector<std::pair<int, int>>& output_inds;  // {node_id, output_idx}
+  size_t max_batch_size;
+  size_t max_workspace_size_bytes;
+  const tensorflow::grappler::GraphProperties& graph_properties;
+  std::unordered_map<string, std::pair<int, string>>* output_edge_map;
+  tensorflow::NodeDef* trt_node;
+  const int precision_mode;
+};
+
+// TODO(sami): Replace references with const reference or pointers
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params);
+tensorflow::Status InjectCalibrationNode(SubGraphParams& params);
+tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph,
+                                                      tensorflow::Node* c_node);
 }  // namespace convert
 }  // namespace tensorrt
 }  // namespace tensorflow
@@ -21,10 +21,11 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/stream_executor.h"
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
-#include "cuda_runtime_api.h"
+#include "cuda/include/cuda_runtime_api.h"
 #include "tensorrt/include/NvInfer.h"
 
 namespace tensorflow {
@@ -113,7 +114,13 @@ void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) {
     ctx->set_output(i, t);
   }
   VLOG(2) << "Filled map for sending";
-  calib_res->calibrator_->setBatch(input_data);
+  // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+  const cudaStream_t* stream = CHECK_NOTNULL(
+      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
+                                                ->stream()
+                                                ->implementation()
+                                                ->CudaStreamMemberHack()));
+  calib_res->calibrator_->setBatch(input_data, *stream);
   VLOG(2) << "Passed calibration data";
   // TODO(aaroey): make sure we wait for the completion of calibration on the
   // last batch in future PR.
@@ -24,8 +24,12 @@ limitations under the License.
 #include "cuda/include/cuda_runtime_api.h"
 
 namespace tensorflow {
-namespace tensorrt {
 static ::tensorflow::tensorrt::Logger logger;
+namespace gpu = ::perftools::gputools;
+using IRuntime = nvinfer1::IRuntime;
+using Dims = nvinfer1::Dims;
+
+namespace tensorrt {
 
 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
   // read serialized_engine
@@ -40,10 +44,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
-  // TODO(samikama) runtime should be taken from a resourcemanager as well.
-  // Only engine should be in the op and context and runtime should be taken
-  // from resourcemanager
-  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
+  // TODO(jie): cudaSetDevice make sure trt engine is allocated on the same
+  // gpu where the input/output is also located.
+  int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
+  cudaSetDevice(gpu_id);
+  int device;
+  cudaGetDevice(&device);
+  if (gpu_id != device) LOG(FATAL) << "set device failed!";
+
+  // TODO(samikama) runtime should be taken from a resourcemanager as well.
+  // Only engine should be in the op and context and runtime should be taken
+  // from resourcemanager
+
+  IRuntime* infer = nvinfer1::createInferRuntime(logger);
   trt_engine_ptr_.reset(infer->deserializeCudaEngine(
       serialized_engine.c_str(), serialized_engine.size(), nullptr));
+
+  trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
   // Runtime is safe to delete after engine creation
   infer->destroy();
@@ -55,7 +70,6 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
 
   size_t binding_index;
   int num_batch = 0;
-  bool valid = true;
   for (int i = 0; i < context->num_inputs(); i++) {
     // Grab the input tensor
     binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
@@ -64,8 +78,12 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
     const TensorShape& input_shape = input_tensor.shape();
     if (i == 0) {
       num_batch = input_shape.dim_size(0);
+      if (num_batch > trt_engine_ptr_->getMaxBatchSize()) {
+        LOG(FATAL) << "input tensor batch larger than max_batch_size: "
+                   << trt_engine_ptr_->getMaxBatchSize();
+      }
     } else if (num_batch != input_shape.dim_size(0)) {
-      valid = false;
+      LOG(FATAL) << "input data inconsistent batch size";
       break;
     }
     switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
@@ -81,9 +99,6 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
     }
   }
 
-  // Might want a different way to inform the user of batch size inconsistency
-  if (!valid) LOG(WARNING) << "input data inconsistent batch size";
-
   for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
     // This is bad that we have to reallocate output buffer every run.
     // Create an output tensor
|
||||
->implementation()
|
||||
->CudaStreamMemberHack()));
|
||||
|
||||
// execution handled by TF since we are getting stream from TF.
|
||||
// it is safe for CPU pointer array (buffers) to go out of scope after enqueue
|
||||
trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr);
|
||||
// TODO(jie): trt enqueue does not return error
|
||||
auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0],
|
||||
*stream, nullptr);
|
||||
VLOG(2) << "enqueue returns: " << ret;
|
||||
// sync should be done by TF.
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
|
||||
|
@@ -27,19 +27,19 @@ void Logger::log(Severity severity, const char* msg) {
   // Suppress info-level messages
   switch (severity) {
     case Severity::kINFO: {  // Mark TRT info messages as debug!
-      VLOG(2) << msg;
+      VLOG(2) << name_ << " " << msg;
       break;
     }
     case Severity::kWARNING: {
-      LOG(WARNING) << msg;
+      LOG(WARNING) << name_ << " " << msg;
      break;
    }
    case Severity::kERROR: {
-      LOG(ERROR) << msg;
+      LOG(ERROR) << name_ << " " << msg;
      break;
    }
    case Severity::kINTERNAL_ERROR: {
-      LOG(FATAL) << msg;
+      LOG(FATAL) << name_ << " " << msg;
      break;
    }
    // This is useless for now. But would catch it in future if enum changes. It
@@ -27,9 +27,11 @@ namespace tensorrt {
 
 // Logger for GIE info/warning/errors
 class Logger : public nvinfer1::ILogger {
- private:
+ public:
+  Logger(string name = "DefaultLogger") : name_(name){};
   void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
 
+ private:
+  string name_;
 };
 
@@ -20,5 +20,6 @@ from __future__ import print_function
 
 # pylint: disable=unused-import,line-too-long
 from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph
 from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
 # pylint: enable=unused-import,line-too-long
@@ -20,11 +20,17 @@ from __future__ import print_function
 
 # pylint: disable=unused-import,line-too-long
+import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
 from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl as _impl
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.util import compat
 # pylint: enable=unused-import,line-too-long
 
 
 # TODO(skama): get outputs from session when implemented as c++
@@ -32,22 +38,33 @@ from tensorflow.python.framework import ops
 def create_inference_graph(input_graph_def,
                            outputs,
                            max_batch_size=1,
-                           max_workspace_size_bytes=2 << 20):
+                           max_workspace_size_bytes=2 << 20,
+                           precision_mode="FP32",
+                           minimum_segment_size=3):
   """Python wrapper for the TRT transormation.
 
   Args:
     input_graph_def: GraphDef object containing a model to be transformed.
-    outputs: List of tensors or node names for the model outputs.
+    outputs: list of tensors or node names for the model outputs.
     max_batch_size: max size for the input batch
     max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
+    precision_mode: one of 'FP32', 'FP16' and 'INT8'
+    minimum_segment_size: the minimum number of nodes required for a subgraph to
+      be replaced by TRTEngineOp.
 
   Returns:
     New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
 
   Raises:
+    ValueError: if the provided precision mode is invalid.
     RuntimeError: if the returned status message is malformed.
   """
+  supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
+  if precision_mode.upper() not in supported_precision_modes:
+    raise ValueError(("precision mode '{}' is not supported."
+                      "It should be one of {}").format(
+                          precision_mode, "{'FP32', 'FP16', 'INT8'}"))
+  mode = supported_precision_modes[precision_mode.upper()]
 
   def py2bytes(inp):
     return inp
@@ -83,7 +100,7 @@ def create_inference_graph(input_graph_def,
   # pair or strings where first one is encoded status and the second
   # one is the transformed graphs protobuf string.
   out = trt_convert(input_graph_def_str, out_names, max_batch_size,
-                    max_workspace_size_bytes)
+                    max_workspace_size_bytes, mode, minimum_segment_size)
   status = to_string(out[0])
   output_graph_def_string = out[1]
   del input_graph_def_str  # Save some memory
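A hedged usage sketch of the extended signature (the frozen graph and the output-node name below are placeholders, not from this commit):

```python
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt

# frozen_graph_def: a frozen tf.GraphDef produced elsewhere (placeholder);
# "logits" is a hypothetical output node name.
trt_graph_def = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,
    outputs=["logits"],
    max_batch_size=8,
    max_workspace_size_bytes=1 << 30,
    precision_mode="FP16",      # new knob: "FP32", "FP16" or "INT8"
    minimum_segment_size=3)     # new knob: smallest subgraph to convert
tf.reset_default_graph()
tf.import_graph_def(trt_graph_def)
```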
@@ -101,3 +118,46 @@ def create_inference_graph(input_graph_def,
   output_graph_def.ParseFromString(output_graph_def_string)
   del output_graph_def_string  # Save some memory
   return output_graph_def
+
+
+def calib_graph_to_infer_graph(calibration_graph_def):
+  """Convert an existing calibration graph to inference graph.
+
+  Args:
+    calibration_graph_def: the calibration GraphDef object with calibration data
+  Returns:
+    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
+  Raises:
+    RuntimeError: if the returned status message is malformed.
+  """
+
+  def py2string(inp):
+    return inp
+
+  def py3string(inp):
+    return inp.decode("utf-8")
+
+  if _six.PY2:
+    to_string = py2string
+  else:
+    to_string = py3string
+
+  graph_str = calibration_graph_def.SerializeToString()
+  out = calib_convert(graph_str)
+  status = to_string(out[0])
+  output_graph_def_string = out[1]
+  del graph_str  # Save some memory
+  if len(status) < 2:
+    raise _impl.UnknownError(None, None, status)
+  if status[:2] != "OK":
+    msg = status.split(";")
+    if len(msg) == 1:
+      raise RuntimeError("Status message is malformed {}".format(status))
+    # pylint: disable=protected-access
+    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
+                                         int(msg[0]))
+    # pylint: enable=protected-access
+  output_graph_def = graph_pb2.GraphDef()
+  output_graph_def.ParseFromString(output_graph_def_string)
+  del output_graph_def_string  # Save some memory
+  return output_graph_def
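Putting the two wrappers together, a hedged sketch of the INT8 flow this function enables (the calibration data and graph are placeholders):

```python
import tensorflow.contrib.tensorrt as trt

# 1) Build a calibration graph: INT8 mode inserts TRTCalibOp nodes.
calib_graph_def = trt.create_inference_graph(
    frozen_graph_def, ["logits"], precision_mode="INT8")
# 2) Run representative batches through calib_graph_def in a session so the
#    calibrators observe realistic activation ranges.
# 3) Convert the calibrated graph into the final inference graph.
infer_graph_def = trt.calib_graph_to_infer_graph(calib_graph_def)
```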
@@ -23,7 +23,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
-#include "cuda_runtime_api.h"
+#include "cuda/include/cuda_runtime_api.h"
 
 namespace tensorflow {
 namespace tensorrt {
@@ -38,22 +38,18 @@ TRTInt8Calibrator::TRTInt8Calibrator(
       done_(false),
       dev_buffers_(dev_buffers),
       calib_running_(false),
+      batch_is_set_(false),
       engine_name_(engine_name) {}
 
-bool TRTInt8Calibrator::setBatch(
-    const std::unordered_map<string, void*>& data) {
-  // TODO(aaroey): make sure that in future PR:
-  // 1. the mutex_lock is outside of the loop
-  // 2. wait() is used instead of wait_for()
-  // 3. done_ is to be protected by the mutex
-  // 4. the first batch is not missed
-  if (done_) return false;
-  while (calib_running_.load(
-      std::memory_order_acquire)) {  // wait while calibration is running
-    tensorflow::mutex_lock l(cond_mtx_);
-    cond_.wait_for(l, std::chrono::milliseconds(50));
-    if (done_) return false;
+bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
+                                 const cudaStream_t stream) {
+  tensorflow::mutex_lock lock(cond_mtx_);
+  while ((calib_running_ || batch_is_set_) &&
+         !done_) {  // wait while calibration is running
+    cond_.wait(lock);
   }
+  if (done_) return false;
+  CHECK(!calib_running_ && !batch_is_set_);
   VLOG(1) << "Set Batch Waiting finished";
   for (const auto it : data) {
     auto devptr = dev_buffers_.find(it.first);
@@ -65,27 +61,32 @@ bool TRTInt8Calibrator::setBatch(
 
     // TODO(aaroey): we should not use sync copy on default stream. Make sure
    // stream->ThenMemcpy() is used in future PRs.
-    auto status =
-        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice);
+    // TODO(sami,aaroey): Need to figureout a way to ensure synchronization
+    // between stream, perhaps using a tensor?
+    auto status = cudaMemcpyAsync(d.first, it.second, d.second,
+                                  cudaMemcpyDeviceToDevice, stream);
     if (status != cudaSuccess) {
       LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
                  << "' failed with " << status;
     }
   }
-  calib_running_.store(true, std::memory_order_release);  // release builder
+
+  // TODO(Sami, aaorey): Find an alternative way!
+  cudaStreamSynchronize(
+      stream);  // we have to wait for the stream before returning!
+  batch_is_set_ = true;
+  cond_.notify_all();
   return true;
 }
 
 bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
                                 int num_bindings) {
-  calib_running_.store(false, std::memory_order_release);  // wait for new batch
-  while (!calib_running_.load(
-      std::memory_order_acquire)) {  // wait until new batch arrives
-    tensorflow::mutex_lock l(cond_mtx_);
-    cond_.wait_for(l, std::chrono::milliseconds(50));
-    if (done_) return false;
+  tensorflow::mutex_lock lock(cond_mtx_);
+  calib_running_ = false;
+  cond_.notify_all();
+  while ((!batch_is_set_ && !done_)) {  // wait until new batch arrives
+    cond_.wait(lock);
+
   }
   if (done_) {
     return false;
@ -100,6 +101,8 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
|
||||
|
||||
bindings[i] = it->second.first;
|
||||
}
|
||||
batch_is_set_ = false;
|
||||
calib_running_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -107,6 +110,12 @@ const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void TRTInt8Calibrator::setDone() {
|
||||
tensorflow::mutex_lock lock(cond_mtx_);
|
||||
done_ = true;
|
||||
cond_.notify_all();
|
||||
}
|
||||
|
||||
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
|
||||
std::size_t length) {}
|
||||
TRTInt8Calibrator::~TRTInt8Calibrator() {
|
||||
@ -115,5 +124,6 @@ TRTInt8Calibrator::~TRTInt8Calibrator() {
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@ -24,7 +24,10 @@ limitations under the License.

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT

#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
// This class provides a 1 element queue to match TFs push model to
@ -39,8 +42,9 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
  int getBatchSize() const override;
  bool getBatch(void* bindings[], const char* names[],
                int num_bindings) override;
  bool setBatch(const std::unordered_map<string, void*>& data);
  void setDone() { done_ = true; }
  bool setBatch(const std::unordered_map<string, void*>& data,
                const cudaStream_t stream);
  void setDone();
  const void* readCalibrationCache(std::size_t& length) override;
  void writeCalibrationCache(const void* ptr, std::size_t length) override;
  ~TRTInt8Calibrator();
@ -55,11 +59,14 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
  const std::unordered_map<string, std::pair<void*, size_t>>
      dev_buffers_;  // map to keep tensorrt input buffers and sizes keyed with
                     // buffer names
  std::atomic_bool calib_running_;
  bool calib_running_;
  bool batch_is_set_;
  string engine_name_;
};

}  // namespace tensorrt
}  // namespace tensorflow

#endif
#endif
#endif  // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_
#endif
#endif
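The calibrator adapts TensorFlow's push model to TensorRT's pull model via what is effectively a one-element queue guarded by `cond_mtx_` and `cond_`: `setBatch()` blocks until the previous batch is consumed, `getBatch()` blocks until a new batch arrives, and `setDone()` unblocks both. A rough Python sketch of the same handoff pattern, illustrative only (all names here are assumptions, not part of this diff):

import threading

class OneElementQueue(object):
  """Illustrative analogue of the calibrator's batch handoff."""

  def __init__(self):
    self._cond = threading.Condition()
    self._batch = None
    self._done = False

  def set_batch(self, batch):  # producer side, like setBatch()
    with self._cond:
      while self._batch is not None and not self._done:
        self._cond.wait()  # wait while the consumer still owns a batch
      if self._done:
        return False
      self._batch = batch
      self._cond.notify_all()
      return True

  def get_batch(self):  # consumer side, like getBatch()
    with self._cond:
      self._batch = None  # release the previous batch to the producer
      self._cond.notify_all()
      while self._batch is None and not self._done:
        self._cond.wait()  # wait until a new batch arrives
      return None if self._done else self._batch

  def set_done(self):  # like setDone(): wake all waiters and stop
    with self._cond:
      self._done = True
      self._cond.notify_all()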
@ -60,6 +60,7 @@ def get_simple_graph_def():


def run_graph(gdef, dumm_inp):
  """Run given graphdef once."""
  gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
  ops.reset_default_graph()
  g = ops.Graph()
@ -74,15 +75,65 @@ def run_graph(gdef, dumm_inp):
    return val


# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
def run_calibration(gdef, dumm_inp):
  """Run given calibration graph multiple times."""
  gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
  ops.reset_default_graph()
  g = ops.Graph()
  with g.as_default():
    inp, out = importer.import_graph_def(
        graph_def=gdef, return_elements=["input", "output"])
    inp = inp.outputs[0]
    out = out.outputs[0]
  with csess.Session(
      config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
    # Run over real calibration data here; we are mimicking a calibration
    # set of 30 different batches. Use as much calibration data as you want.
    for _ in range(30):
      val = sess.run(out, {inp: dumm_inp})
  return val


if "__main__" in __name__:
  inp_dims = (100, 24, 24, 2)
  dummy_input = np.random.random_sample(inp_dims)
  gdef = get_simple_graph_def()
  orig_graph = get_simple_graph_def()  # use a frozen graph for inference
  # Get optimized graph
  trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0])
  o1 = run_graph(gdef, dummy_input)
  trt_graph = trt.create_inference_graph(
      input_graph_def=orig_graph,
      outputs=["output"],
      max_batch_size=inp_dims[0],
      max_workspace_size_bytes=1 << 25,
      precision_mode="FP32",  # TRT engine precision: "FP32", "FP16" or "INT8"
      minimum_segment_size=2  # minimum number of nodes in an engine
  )
  o1 = run_graph(orig_graph, dummy_input)
  o2 = run_graph(trt_graph, dummy_input)
  o3 = run_graph(trt_graph, dummy_input)
  assert np.array_equal(o1, o2)
  assert np.array_equal(o3, o2)  # sanity check
  fp16_graph = trt.create_inference_graph(
      input_graph_def=orig_graph,
      outputs=["output"],
      max_batch_size=inp_dims[0],
      max_workspace_size_bytes=1 << 25,
      precision_mode="FP16",  # TRT engine precision: "FP32", "FP16" or "INT8"
      minimum_segment_size=2  # minimum number of nodes in an engine
  )
  int8_calib_gdef = trt.create_inference_graph(
      input_graph_def=orig_graph,
      outputs=["output"],
      max_batch_size=inp_dims[0],
      max_workspace_size_bytes=1 << 25,
      precision_mode="INT8",  # TRT engine precision: "FP32", "FP16" or "INT8"
      minimum_segment_size=2  # minimum number of nodes in an engine
  )
  o4 = run_graph(fp16_graph, dummy_input)
  _ = run_calibration(int8_calib_gdef, dummy_input)
  int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
  o5 = run_graph(int8_graph, dummy_input)
  assert np.allclose(o1, o4)
  assert np.allclose(o1, o5)
  print("Pass")
@ -64,13 +64,17 @@ PyObject* pair_helper(std::pair<string, string>* in) {
%ignoreall
%unignore tensorflow;
%unignore trt_convert;
%unignore calib_convert;

%{

std::pair<string, string> trt_convert(
    string graph_def_string,  // The serialized GraphDef string.
    std::vector<string> output_names,
    size_t max_batch_size,
    size_t max_workspace_size_bytes
    size_t max_workspace_size_bytes,
    int precision_mode,
    int minimum_segment_size
    // Unfortunately we can't use TF_Status here since it
    // is in c/c_api and brings in a lot of other libraries
    // which in turn declare ops. These ops are included
@ -90,16 +94,64 @@ std::pair<string, string> trt_convert(
    return std::pair<string, string>{out_status, ""};
  }

  if (precision_mode < 0 || precision_mode > 2) {
    out_status = "InvalidArgument;Invalid precision_mode";
    return std::pair<string, string>{out_status, ""};
  }
  if (!output_names.size()) {
    out_status = "InvalidArgument;Size of the output_names vector is 0";
    return std::pair<string, string>{out_status, ""};
    // return "";
  }
  tensorflow::GraphDef outGraph;
  tensorflow::Status conversion_status =
      tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
          graph_def, output_names, max_batch_size, max_workspace_size_bytes,
          &outGraph);
          &outGraph, precision_mode, minimum_segment_size);
  if (!conversion_status.ok()) {
    auto retCode = (int)conversion_status.code();
    char buff[2000];
    snprintf(buff, 2000, "%d;%s", retCode,
             conversion_status.error_message().c_str());
    out_status = buff;
    return std::pair<string, string>{out_status, ""};
  }
  string result;
  if (!outGraph.SerializeToString(&result)) {
    out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
    return std::pair<string, string>{out_status, ""};
  }
  out_status = "OK;All good!";
  return std::pair<string, string>{out_status, result};
#else
  // Returns FAILED_PRECONDITION.
  return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
}

std::pair<string, string> calib_convert(string graph_def_string  // const tensorflow::GraphDef&
    // Unfortunately we can't use TF_Status here since it
    // is in c/c_api and brings in a lot of other libraries
    // which in turn declare ops. These ops are included
    // statically in our library and cause an abort when
    // the module is loaded, due to double registration.
    // Until TensorFlow properly exposes these headers
    // we have to work around this by returning a string
    // and converting it to an exception on the Python side.
    //, TF_Status* out_status) {
) {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  string out_status;

  tensorflow::GraphDef graph_def;
  if (!graph_def.ParseFromString(graph_def_string)) {
    out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
    return std::pair<string, string>{out_status, ""};
  }

  tensorflow::GraphDef outGraph;
  tensorflow::Status conversion_status =
      tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def,
                                                                   &outGraph);
  if (!conversion_status.ok()) {
    auto retCode = (int)conversion_status.code();
    char buff[2000];
@ -122,10 +174,13 @@ std::pair<string, string> trt_convert(
}
%}

std::pair<string, string> calib_convert(string graph_def_string);

std::pair<string, string> trt_convert(string graph_def_string,
                                      std::vector<string> output_names,
                                      size_t max_batch_size,
                                      size_t max_workspace_size_bytes);
                                      size_t max_workspace_size_bytes,
                                      int precision_mode, int minimum_segment_size);


%unignoreall
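Both wrapped functions report status through the first element of the returned pair, encoded as "<numeric code>;<message>" with an "OK;..." prefix on success (e.g. "9;TensorRT is not enabled!" corresponds to FAILED_PRECONDITION). A hedged sketch of decoding that convention, mirroring the parsing in `calib_graph_to_infer_graph` above (`parse_status` is a hypothetical helper, not part of this diff):

def parse_status(status):
  # Illustrative only: decode the "code;message" strings produced above.
  if status[:2] == "OK":
    return None  # success; the pair's second element holds the GraphDef bytes
  code, _, message = status.partition(";")
  if not message:
    raise RuntimeError("Status message is malformed {}".format(status))
  return int(code), message  # e.g. (9, "TensorRT is not enabled!")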
@ -25,7 +25,10 @@ py_test(
    srcs = ["predict_test.py"],
    data = ["data/period_trend.csv"],
    srcs_version = "PY2AND3",
    tags = ["notsan"],  # b/67513579
    tags = [
        "no_windows",  # TODO: needs investigation on Windows
        "notsan",  # b/67513579
    ],
    deps = [
        ":predict",
        "//tensorflow/python:client_testlib",

@ -156,9 +156,7 @@ py_test(
        "head_test.py",
    ],
    srcs_version = "PY2AND3",
    tags = [
        "no_pip_gpu",  # b/63391119
    ],
    tags = ["no_pip_gpu"],  # b/63391119
    deps = [
        ":feature_keys",
        ":head",
@ -427,6 +425,7 @@ py_test(
    srcs_version = "PY2AND3",
    tags = [
        "no_pip_gpu",  # b/63391119
        "no_windows",  # TODO: needs investigation on Windows
    ],
    deps = [
        ":feature_keys",