Merge changes from github.
PiperOrigin-RevId: 182258809
This commit is contained in:
parent
7699ea8bee
commit
d8697935d3
@ -31,6 +31,10 @@ tracking requests and bugs. So please see
|
||||
[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions
|
||||
and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
|
||||
|
||||
The TensorFlow project strives to abide by generally accepted best practices in open-source software development:
|
||||
|
||||
[data:image/s3,"s3://crabby-images/c6c36/c6c362bfbe92810e99b0144b17b487ae71a2edc6" alt="CII Best Practices"](https://bestpractices.coreinfrastructure.org/projects/1486)
|
||||
|
||||
## Installation
|
||||
*See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.*
|
||||
|
||||
|
@ -364,6 +364,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "override_eigen_strong_inline",
|
||||
values = {"define": "override_eigen_strong_inline=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
|
@ -67,7 +67,7 @@ class AllocationTracker {
|
||||
// The device that the memory is allocated on.
|
||||
int device_ordinal;
|
||||
|
||||
// This is the number of times this memory allocation is refered to by
|
||||
// This is the number of times this memory allocation is referred to by
|
||||
// registered data handles.
|
||||
int ref_count;
|
||||
};
|
||||
|
@ -176,7 +176,7 @@ class ColumnMajorMatrixVectorProductEmitter {
|
||||
}
|
||||
|
||||
// Load a tile of values from the RHS. For the RHS a "tile" is a contiguous
|
||||
// sequnce of `count` values, each one broadcasted to the vector width.
|
||||
// sequence of `count` values, each one broadcasted to the vector width.
|
||||
std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
|
||||
llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
|
||||
std::vector<llvm::Value*> result;
|
||||
|
@ -62,7 +62,7 @@ namespace {
|
||||
// &tagged_instructions));
|
||||
//
|
||||
// Instructions that are "tagged" with a context-specific string will
|
||||
// be returned in 'tagged_instructions' for further procesing (i.e. parsing
|
||||
// be returned in 'tagged_instructions' for further processing (i.e. parsing
|
||||
// constants or recording the tuple_index).
|
||||
//
|
||||
class ExprTree {
|
||||
|
@ -296,7 +296,7 @@ class LayoutAssignment : public HloPassInterface {
|
||||
const ResultLayoutConstraint& layout_constraint,
|
||||
LayoutConstraints* constraints);
|
||||
|
||||
// By default LayoutAssignment ensures that inputs and ouptuts of CustomCalls
|
||||
// By default LayoutAssignment ensures that inputs and outputs of CustomCalls
|
||||
// have the "major-first" layout (i.e. {n, n-1, ..., 0}).
|
||||
//
|
||||
// If this function returns true, LayoutAssignment does not set a layout for
|
||||
|
@ -51,7 +51,7 @@ class IndicesRowIterator
|
||||
return tmp;
|
||||
}
|
||||
|
||||
reference operator*() { return iter_->ix()(row_idx_, 0); }
|
||||
reference operator*() const { return iter_->ix()(row_idx_, 0); }
|
||||
|
||||
pointer operator->() { return &iter_->ix()(row_idx_, 0); }
|
||||
|
||||
|
@ -39,11 +39,7 @@ ExternalProject_Add(boringssl
|
||||
# BUILD_IN_SOURCE 1
|
||||
INSTALL_COMMAND ""
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||
)
|
||||
|
@ -42,11 +42,7 @@ ExternalProject_Add(jsoncpp
|
||||
BUILD_IN_SOURCE 1
|
||||
INSTALL_COMMAND ""
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||
)
|
||||
|
6
tensorflow/contrib/cmake/external/lmdb.cmake
vendored
6
tensorflow/contrib/cmake/external/lmdb.cmake
vendored
@ -29,11 +29,7 @@ ExternalProject_Add(lmdb
|
||||
INSTALL_DIR ${lmdb_INSTALL}
|
||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||
-DCMAKE_INSTALL_PREFIX:STRING=${lmdb_INSTALL}
|
||||
|
6
tensorflow/contrib/cmake/external/png.cmake
vendored
6
tensorflow/contrib/cmake/external/png.cmake
vendored
@ -41,11 +41,7 @@ ExternalProject_Add(png
|
||||
INSTALL_DIR ${png_INSTALL}
|
||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||
-DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL}
|
||||
|
@ -44,11 +44,7 @@ ExternalProject_Add(protobuf
|
||||
${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS}
|
||||
INSTALL_COMMAND ""
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||
-DZLIB_ROOT:STRING=${ZLIB_INSTALL}
|
||||
|
6
tensorflow/contrib/cmake/external/re2.cmake
vendored
6
tensorflow/contrib/cmake/external/re2.cmake
vendored
@ -38,11 +38,7 @@ ExternalProject_Add(re2
|
||||
BUILD_IN_SOURCE 1
|
||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_INSTALL_PREFIX:STRING=${re2_INSTALL}
|
||||
-DRE2_BUILD_TESTING:BOOL=OFF
|
||||
|
@ -40,11 +40,7 @@ ExternalProject_Add(snappy
|
||||
LOG_CONFIGURE ON
|
||||
LOG_BUILD ON
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||
-DSNAPPY_BUILD_TESTS:BOOL=OFF
|
||||
|
@ -54,11 +54,7 @@ else()
|
||||
INSTALL_DIR ${sqlite_INSTALL}
|
||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||
-DCMAKE_INSTALL_PREFIX:STRING=${sqlite_INSTALL}
|
||||
|
6
tensorflow/contrib/cmake/external/zlib.cmake
vendored
6
tensorflow/contrib/cmake/external/zlib.cmake
vendored
@ -42,11 +42,7 @@ ExternalProject_Add(zlib
|
||||
BUILD_IN_SOURCE 1
|
||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||
CMAKE_CACHE_ARGS
|
||||
if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE)
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
else()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF
|
||||
endif()
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL}
|
||||
)
|
||||
|
@ -1354,7 +1354,7 @@ class _CudnnRNN(object):
|
||||
params: the parameter buffer created for this model.
|
||||
is_training: whether this operation will be used in training or inference.
|
||||
Returns:
|
||||
output: the output sequuence.
|
||||
output: the output sequence.
|
||||
output_h: the final state for h.
|
||||
output_c: the final state for c. This is only relevant for LSTM.
|
||||
"""
|
||||
@ -1472,7 +1472,7 @@ class CudnnLSTM(_CudnnRNN):
|
||||
params: the parameter buffer created for this model.
|
||||
is_training: whether this operation will be used in training or inference.
|
||||
Returns:
|
||||
output: the output sequuence.
|
||||
output: the output sequence.
|
||||
output_h: the final state for h.
|
||||
output_c: the final state for c.
|
||||
"""
|
||||
|
@ -185,8 +185,8 @@ py_test(
|
||||
srcs = ["interleave_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
],
|
||||
deps = [
|
||||
":dataset_serialization_test",
|
||||
|
@ -214,7 +214,7 @@ class Datasets(object):
|
||||
"""Load the Penn Treebank dataset.
|
||||
|
||||
Args:
|
||||
path: Path to the data/ directory of the dataset from from Tomas Mikolov's
|
||||
path: Path to the data/ directory of the dataset from Tomas Mikolov's
|
||||
webpage - http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
|
||||
"""
|
||||
|
||||
|
@ -35,6 +35,7 @@ tf_custom_op_py_library(
|
||||
"python/ops/critical_section_ops.py",
|
||||
"python/ops/ops.py",
|
||||
"python/ops/prettyprint_ops.py",
|
||||
"python/ops/script_ops.py",
|
||||
"python/ops/sort_ops.py",
|
||||
"python/ops/variables.py",
|
||||
],
|
||||
@ -62,6 +63,7 @@ tf_custom_op_py_library(
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:state_ops_gen",
|
||||
|
@ -81,6 +81,7 @@ See the @{$python/contrib.framework} guide.
|
||||
@@load_linear_multiclass_bias_initializer
|
||||
@@load_variable_slot_initializer
|
||||
|
||||
@@py_func
|
||||
@@sort
|
||||
|
||||
@@CriticalSection
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.contrib.framework.python.ops.checkpoint_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.critical_section_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.ops import *
|
||||
from tensorflow.contrib.framework.python.ops.prettyprint_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.script_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.sort_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.variables import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
142
tensorflow/contrib/framework/python/ops/script_ops.py
Normal file
142
tensorflow/contrib/framework/python/ops/script_ops.py
Normal file
@ -0,0 +1,142 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Script Language Operators. See the @{$python/script_ops} guide.
|
||||
|
||||
@@py_func
|
||||
"""
|
||||
|
||||
# pylint: disable=g-bad-name
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops.script_ops import py_func as _py_func
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
__all__ = ['py_func']
|
||||
|
||||
|
||||
def py_func(func,
|
||||
args=(),
|
||||
kwargs=None,
|
||||
output_types=None,
|
||||
output_shapes=None,
|
||||
stateful=True,
|
||||
name=None):
|
||||
"""Wraps a python function and uses it as a TensorFlow op.
|
||||
|
||||
This function is a wrapper around `tf.py_func` and improve it with kwargs
|
||||
and output_shapes. Further it changed some argument names.
|
||||
|
||||
Given a python function `func`, which takes numpy arrays as its
|
||||
inputs and returns numpy arrays as its outputs, wrap this function as an
|
||||
operation in a TensorFlow graph. The following snippet constructs a simple
|
||||
TensorFlow graph that invokes the `np.sinh()` NumPy function as a operation
|
||||
in the graph:
|
||||
|
||||
```python
|
||||
def my_func(x):
|
||||
# x will be a numpy array with the contents of the placeholder below
|
||||
return np.sinh(x)
|
||||
inp = tf.placeholder(tf.float32)
|
||||
y = tf.py_func(my_func, [inp], tf.float32)
|
||||
```
|
||||
|
||||
|
||||
**N.B.** The `tf.py_func()` operation has the following known limitations:
|
||||
|
||||
* The body of the function (i.e. `func`) will not be serialized in a
|
||||
`GraphDef`. Therefore, you should not use this function if you need to
|
||||
serialize your model and restore it in a different environment.
|
||||
|
||||
* The operation must run in the same address space as the Python program
|
||||
that calls `tf.py_func()`. If you are using distributed TensorFlow, you
|
||||
must run a `tf.train.Server` in the same process as the program that calls
|
||||
`tf.py_func()` and you must pin the created operation to a device in that
|
||||
server (e.g. using `with tf.device():`).
|
||||
|
||||
Args:
|
||||
func: A Python function, which accepts a list of NumPy `ndarray` objects
|
||||
having element types that match the corresponding `tf.Tensor` objects
|
||||
in `inp`, and returns a list of `ndarray` objects (or a single `ndarray`)
|
||||
having element types that match the corresponding values in `Tout`.
|
||||
args: A list of `Tensor` objects.
|
||||
kwargs: A dict with `Tensor` objects as values.
|
||||
output_types: A nested structure of tensorflow data types or a single
|
||||
tensorflow data type if there is only one, indicating what `func` returns.
|
||||
output_shapes: Same as output_types, except the types are replaces with
|
||||
shapes (optional).
|
||||
stateful: (Boolean.) If True, the function should be considered stateful.
|
||||
If a function is stateless, when given the same input it will return the
|
||||
same output and have no observable side effects. Optimizations such as
|
||||
common subexpression elimination are only performed on stateless
|
||||
operations.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
Tensorflow op that wraps the input python function.
|
||||
"""
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if not isinstance(args, (list, tuple)):
|
||||
raise TypeError('args must be list and not {}. args: {}'.format(
|
||||
type(args), args))
|
||||
|
||||
if not isinstance(kwargs, dict):
|
||||
raise TypeError('kwargs must be dict and not {}. args: {}'.format(
|
||||
type(kwargs), kwargs))
|
||||
|
||||
# For dynamic type inference use callable output_types and output_shapes
|
||||
if callable(output_types):
|
||||
# If callable assume same signature and call with tensors and get the types
|
||||
output_types = output_types(*args, **kwargs)
|
||||
if callable(output_shapes):
|
||||
# If callable assume same signature and call with tensors and get the shapes
|
||||
output_shapes = output_shapes(*args, **kwargs)
|
||||
|
||||
flat_output_types = nest.flatten(output_types)
|
||||
args = (args, kwargs)
|
||||
flat_args = nest.flatten(args)
|
||||
|
||||
def python_function_wrapper(*py_args):
|
||||
py_args, py_kwargs = nest.pack_sequence_as(args, py_args)
|
||||
|
||||
ret = func(*py_args, **py_kwargs)
|
||||
# TODO(alextp): Catch Exceptions and improve msg, because tensorflow
|
||||
# ist not able to preserve the traceback, i.e. the Exceptions does not
|
||||
# contain any information where the Exception was raised.
|
||||
nest.assert_shallow_structure(output_types, ret)
|
||||
return nest.flatten(ret)
|
||||
|
||||
flat_values = _py_func(
|
||||
python_function_wrapper,
|
||||
flat_args,
|
||||
flat_output_types,
|
||||
stateful=stateful,
|
||||
name=name)
|
||||
|
||||
if output_shapes is not None:
|
||||
# I am not sure if this is nessesary
|
||||
output_shapes = nest.map_structure_up_to(
|
||||
output_types, tensor_shape.as_shape, output_shapes)
|
||||
|
||||
flattened_shapes = nest.flatten(output_shapes)
|
||||
for ret_t, shape in zip(flat_values, flattened_shapes):
|
||||
ret_t.set_shape(shape)
|
||||
|
||||
return nest.pack_sequence_as(output_types, flat_values)
|
@ -343,7 +343,7 @@ def _tensor_pool_adjusted_model(model, tensor_pool_fn):
|
||||
`tensor_pool_fn` is None.
|
||||
|
||||
Raises:
|
||||
ValueError: If tensor pool does not suport the `model`.
|
||||
ValueError: If tensor pool does not support the `model`.
|
||||
"""
|
||||
if tensor_pool_fn is None:
|
||||
return model
|
||||
|
75
tensorflow/contrib/lite/examples/label_image/BUILD
Normal file
75
tensorflow/contrib/lite/examples/label_image/BUILD
Normal file
@ -0,0 +1,75 @@
|
||||
# Description:
|
||||
# TensorFlow Lite Example Label Image.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
|
||||
|
||||
exports_files(glob([
|
||||
"testdata/*.bmp",
|
||||
]))
|
||||
|
||||
tf_cc_binary(
|
||||
name = "label_image",
|
||||
srcs = [
|
||||
"get_top_n.h",
|
||||
"get_top_n_impl.h",
|
||||
"label_image.cc",
|
||||
],
|
||||
linkopts = tflite_linkopts() + select({
|
||||
"//tensorflow:android": [
|
||||
"-pie", # Android 5.0 and later supports only PIE
|
||||
"-lm", # some builtin ops, e.g., tanh, need -lm
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":bitmap_helpers",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bitmap_helpers",
|
||||
srcs = ["bitmap_helpers.cc"],
|
||||
hdrs = [
|
||||
"bitmap_helpers.h",
|
||||
"bitmap_helpers_impl.h",
|
||||
"label_image.h",
|
||||
],
|
||||
deps = ["//tensorflow/contrib/lite:string"],
|
||||
)
|
||||
|
||||
# TODO(ahentz): Test disabled as it has a memory leek from read_bmp
|
||||
# cc_test(
|
||||
# name = "label_image_test",
|
||||
# srcs = [
|
||||
# "get_top_n.h",
|
||||
# "get_top_n_impl.h",
|
||||
# "label_image_test.cc",
|
||||
# ],
|
||||
# data = [
|
||||
# "testdata/grace_hopper.bmp",
|
||||
# ],
|
||||
# deps = [
|
||||
# ":bitmap_helpers",
|
||||
# "//testing/base/public:gunit",
|
||||
# ],
|
||||
# )
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
120
tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc
Normal file
120
tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc
Normal file
@ -0,0 +1,120 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include <unistd.h> // NOLINT(build/include_order)
|
||||
|
||||
#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h"
|
||||
|
||||
#define LOG(x) std::cerr
|
||||
|
||||
namespace tflite {
|
||||
namespace label_image {
|
||||
|
||||
uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output,
|
||||
int width, int height, int channels, bool top_down) {
|
||||
for (int i = 0; i < height; i++) {
|
||||
int src_pos;
|
||||
int dst_pos;
|
||||
|
||||
for (int j = 0; j < width; j++) {
|
||||
if (!top_down) {
|
||||
src_pos = ((height - 1 - i) * row_size) + j * channels;
|
||||
} else {
|
||||
src_pos = i * row_size + j * channels;
|
||||
}
|
||||
|
||||
dst_pos = (i * width + j) * channels;
|
||||
|
||||
switch (channels) {
|
||||
case 1:
|
||||
output[dst_pos] = input[src_pos];
|
||||
break;
|
||||
case 3:
|
||||
// BGR -> RGB
|
||||
output[dst_pos] = input[src_pos + 2];
|
||||
output[dst_pos + 1] = input[src_pos + 1];
|
||||
output[dst_pos + 2] = input[src_pos];
|
||||
break;
|
||||
case 4:
|
||||
// BGRA -> RGBA
|
||||
output[dst_pos] = input[src_pos + 2];
|
||||
output[dst_pos + 1] = input[src_pos + 1];
|
||||
output[dst_pos + 2] = input[src_pos];
|
||||
output[dst_pos + 3] = input[src_pos + 3];
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unexpected number of channels: " << channels;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
|
||||
int* channels, Settings* s) {
|
||||
int begin, end;
|
||||
|
||||
std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
|
||||
if (!file) {
|
||||
LOG(FATAL) << "input file " << input_bmp_name << " not found\n";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
begin = file.tellg();
|
||||
file.seekg(0, std::ios::end);
|
||||
end = file.tellg();
|
||||
size_t len = end - begin;
|
||||
|
||||
if (s->verbose) LOG(INFO) << "len: " << len << "\n";
|
||||
|
||||
const uint8_t* img_bytes = new uint8_t[len];
|
||||
file.seekg(0, std::ios::beg);
|
||||
file.read((char*)img_bytes, len);
|
||||
const int32_t header_size =
|
||||
*(reinterpret_cast<const int32_t*>(img_bytes + 10));
|
||||
*width = *(reinterpret_cast<const int32_t*>(img_bytes + 18));
|
||||
*height = *(reinterpret_cast<const int32_t*>(img_bytes + 22));
|
||||
const int32_t bpp = *(reinterpret_cast<const int32_t*>(img_bytes + 28));
|
||||
*channels = bpp / 8;
|
||||
|
||||
if (s->verbose)
|
||||
LOG(INFO) << "width, height, channels: " << *width << ", " << *height
|
||||
<< ", " << *channels << "\n";
|
||||
|
||||
// there may be padding bytes when the width is not a multiple of 4 bytes
|
||||
// 8 * channels == bits per pixel
|
||||
const int row_size = (8 * *channels * *width + 31) / 32 * 4;
|
||||
|
||||
// if height is negative, data layout is top down
|
||||
// otherwise, it's bottom up
|
||||
bool top_down = (*height < 0);
|
||||
|
||||
// Decode image, allocating tensor once the image size is known
|
||||
uint8_t* output = new uint8_t[abs(*height) * *width * *channels];
|
||||
const uint8_t* bmp_pixels = &img_bytes[header_size];
|
||||
return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height),
|
||||
*channels, top_down);
|
||||
}
|
||||
|
||||
} // namespace label_image
|
||||
} // namespace tflite
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H
|
||||
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H
|
||||
|
||||
#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h"
|
||||
#include "tensorflow/contrib/lite/examples/label_image/label_image.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace label_image {
|
||||
|
||||
uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
|
||||
int* channels, Settings* s);
|
||||
|
||||
template <class T>
|
||||
void downsize(T* out, uint8_t* in, int image_height, int image_width,
|
||||
int image_channels, int wanted_height, int wanted_width,
|
||||
int wanted_channels, Settings* s);
|
||||
|
||||
// explicit instantiation
|
||||
template void downsize<uint8_t>(uint8_t*, unsigned char*, int, int, int, int,
|
||||
int, int, Settings*);
|
||||
template void downsize<float>(float*, unsigned char*, int, int, int, int, int,
|
||||
int, Settings*);
|
||||
|
||||
} // namespace label_image
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H
|
||||
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H
|
||||
|
||||
#include "tensorflow/contrib/lite/examples/label_image/label_image.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace label_image {
|
||||
|
||||
template <class T>
|
||||
void downsize(T* out, uint8_t* in, int image_height, int image_width,
|
||||
int image_channels, int wanted_height, int wanted_width,
|
||||
int wanted_channels, Settings* s) {
|
||||
for (int y = 0; y < wanted_height; ++y) {
|
||||
const int in_y = (y * image_height) / wanted_height;
|
||||
uint8_t* in_row = in + (in_y * image_width * image_channels);
|
||||
T* out_row = out + (y * wanted_width * wanted_channels);
|
||||
for (int x = 0; x < wanted_width; ++x) {
|
||||
const int in_x = (x * image_width) / wanted_width;
|
||||
uint8_t* in_pixel = in_row + (in_x * image_channels);
|
||||
T* out_pixel = out_row + (x * wanted_channels);
|
||||
for (int c = 0; c < wanted_channels; ++c) {
|
||||
if (s->input_floating)
|
||||
out_pixel[c] = (in_pixel[c] - s->input_mean) / s->input_std;
|
||||
else
|
||||
out_pixel[c] = in_pixel[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace label_image
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H
|
38
tensorflow/contrib/lite/examples/label_image/get_top_n.h
Normal file
38
tensorflow/contrib/lite/examples/label_image/get_top_n.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H
|
||||
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H
|
||||
|
||||
#include "tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace label_image {
|
||||
|
||||
template <class T>
|
||||
void get_top_n(T* prediction, int prediction_size, size_t num_results,
|
||||
float threshold, std::vector<std::pair<float, int>>* top_results,
|
||||
bool input_floating);
|
||||
|
||||
// explicit instantiation so that we can use them otherwhere
|
||||
template void get_top_n<uint8_t>(uint8_t*, int, size_t, float,
|
||||
std::vector<std::pair<float, int>>*, bool);
|
||||
template void get_top_n<float>(float*, int, size_t, float,
|
||||
std::vector<std::pair<float, int>>*, bool);
|
||||
|
||||
} // namespace label_image
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H
|
@ -0,0 +1,70 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H
|
||||
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
|
||||
namespace tflite {
|
||||
namespace label_image {
|
||||
|
||||
extern bool input_floating;
|
||||
|
||||
// Returns the top N confidence values over threshold in the provided vector,
|
||||
// sorted by confidence in descending order.
|
||||
template <class T>
|
||||
void get_top_n(T* prediction, int prediction_size, size_t num_results,
|
||||
float threshold, std::vector<std::pair<float, int>>* top_results,
|
||||
bool input_floating) {
|
||||
// Will contain top N results in ascending order.
|
||||
std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
|
||||
std::greater<std::pair<float, int>>>
|
||||
top_result_pq;
|
||||
|
||||
const long count = prediction_size; // NOLINT(runtime/int)
|
||||
for (int i = 0; i < count; ++i) {
|
||||
float value;
|
||||
if (input_floating)
|
||||
value = prediction[i];
|
||||
else
|
||||
value = prediction[i] / 255.0;
|
||||
// Only add it if it beats the threshold and has a chance at being in
|
||||
// the top N.
|
||||
if (value < threshold) {
|
||||
continue;
|
||||
}
|
||||
|
||||
top_result_pq.push(std::pair<float, int>(value, i));
|
||||
|
||||
// If at capacity, kick the smallest value out.
|
||||
if (top_result_pq.size() > num_results) {
|
||||
top_result_pq.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Copy to output vector and reverse into descending order.
|
||||
while (!top_result_pq.empty()) {
|
||||
top_results->push_back(top_result_pq.top());
|
||||
top_result_pq.pop();
|
||||
}
|
||||
std::reverse(top_results->begin(), top_results->end());
|
||||
}
|
||||
|
||||
} // namespace label_image
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H
|
300
tensorflow/contrib/lite/examples/label_image/label_image.cc
Normal file
300
tensorflow/contrib/lite/examples/label_image/label_image.cc
Normal file
@ -0,0 +1,300 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include <fcntl.h> // NOLINT(build/include_order)
|
||||
#include <getopt.h> // NOLINT(build/include_order)
|
||||
#include <sys/time.h> // NOLINT(build/include_order)
|
||||
#include <sys/types.h> // NOLINT(build/include_order)
|
||||
#include <sys/uio.h> // NOLINT(build/include_order)
|
||||
#include <unistd.h> // NOLINT(build/include_order)
|
||||
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/optional_debug_tools.h"
|
||||
#include "tensorflow/contrib/lite/string_util.h"
|
||||
|
||||
#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h"
|
||||
#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h"
|
||||
|
||||
#define LOG(x) std::cerr
|
||||
|
||||
namespace tflite {
|
||||
namespace label_image {
|
||||
|
||||
double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
|
||||
|
||||
// Takes a file name, and loads a list of labels from it, one per line, and
|
||||
// returns a vector of the strings. It pads with empty strings so the length
|
||||
// of the result is a multiple of 16, because our model expects that.
|
||||
TfLiteStatus ReadLabelsFile(const string& file_name,
|
||||
std::vector<string>* result,
|
||||
size_t* found_label_count) {
|
||||
std::ifstream file(file_name);
|
||||
if (!file) {
|
||||
LOG(FATAL) << "Labels file " << file_name << " not found\n";
|
||||
return kTfLiteError;
|
||||
}
|
||||
result->clear();
|
||||
string line;
|
||||
while (std::getline(file, line)) {
|
||||
result->push_back(line);
|
||||
}
|
||||
*found_label_count = result->size();
|
||||
const int padding = 16;
|
||||
while (result->size() % padding) {
|
||||
result->emplace_back();
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void RunInference(Settings* s) {
|
||||
if (!s->model_name.c_str()) {
|
||||
LOG(ERROR) << "no model file name\n";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::unique_ptr<tflite::FlatBufferModel> model;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
|
||||
if (!model) {
|
||||
LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
|
||||
exit(-1);
|
||||
}
|
||||
LOG(INFO) << "Loaded model " << s->model_name << "\n";
|
||||
model->error_reporter();
|
||||
LOG(INFO) << "resolved reporter\n";
|
||||
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
|
||||
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
|
||||
if (!interpreter) {
|
||||
LOG(FATAL) << "Failed to construct interpreter\n";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
interpreter->UseNNAPI(s->accel);
|
||||
|
||||
if (s->verbose) {
|
||||
LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
|
||||
LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
|
||||
LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
|
||||
LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";
|
||||
|
||||
int t_size = interpreter->tensors_size();
|
||||
for (int i = 0; i < t_size; i++) {
|
||||
if (interpreter->tensor(i)->name)
|
||||
LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
|
||||
<< interpreter->tensor(i)->bytes << ", "
|
||||
<< interpreter->tensor(i)->type << ", "
|
||||
<< interpreter->tensor(i)->params.scale << ", "
|
||||
<< interpreter->tensor(i)->params.zero_point << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (s->number_of_threads != -1) {
|
||||
interpreter->SetNumThreads(s->number_of_threads);
|
||||
}
|
||||
|
||||
int image_width = 224;
|
||||
int image_height = 224;
|
||||
int image_channels = 3;
|
||||
uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height,
|
||||
&image_channels, s);
|
||||
|
||||
int input = interpreter->inputs()[0];
|
||||
if (s->verbose) LOG(INFO) << "input: " << input << "\n";
|
||||
|
||||
const std::vector<int> inputs = interpreter->inputs();
|
||||
const std::vector<int> outputs = interpreter->outputs();
|
||||
|
||||
if (s->verbose) {
|
||||
LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
|
||||
LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
|
||||
}
|
||||
|
||||
if (interpreter->AllocateTensors() != kTfLiteOk) {
|
||||
LOG(FATAL) << "Failed to allocate tensors!";
|
||||
}
|
||||
|
||||
if (s->verbose) PrintInterpreterState(interpreter.get());
|
||||
|
||||
// get input dimension from the input tensor metadata
|
||||
// assuming one input only
|
||||
TfLiteIntArray* dims = interpreter->tensor(input)->dims;
|
||||
int wanted_height = dims->data[1];
|
||||
int wanted_width = dims->data[2];
|
||||
int wanted_channels = dims->data[3];
|
||||
|
||||
if (s->input_floating) {
|
||||
downsize<float>(interpreter->typed_tensor<float>(input), in, image_height,
|
||||
image_width, image_channels, wanted_height, wanted_width,
|
||||
wanted_channels, s);
|
||||
} else {
|
||||
downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
|
||||
image_height, image_width, image_channels, wanted_height,
|
||||
wanted_width, wanted_channels, s);
|
||||
}
|
||||
|
||||
struct timeval start_time, stop_time;
|
||||
gettimeofday(&start_time, NULL);
|
||||
for (int i = 0; i < s->loop_count; i++) {
|
||||
if (interpreter->Invoke() != kTfLiteOk) {
|
||||
LOG(FATAL) << "Failed to invoke tflite!\n";
|
||||
}
|
||||
}
|
||||
gettimeofday(&stop_time, NULL);
|
||||
LOG(INFO) << "invoked \n";
|
||||
LOG(INFO) << "average time: "
|
||||
<< (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
|
||||
<< " ms \n";
|
||||
|
||||
const int output_size = 1000;
|
||||
const size_t num_results = 5;
|
||||
const float threshold = 0.001f;
|
||||
|
||||
std::vector<std::pair<float, int>> top_results;
|
||||
|
||||
if (s->input_floating) {
|
||||
get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
|
||||
num_results, threshold, &top_results, s->input_floating);
|
||||
} else {
|
||||
get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
|
||||
output_size, num_results, threshold, &top_results,
|
||||
s->input_floating);
|
||||
}
|
||||
|
||||
std::vector<string> labels;
|
||||
size_t label_count;
|
||||
|
||||
if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
|
||||
exit(-1);
|
||||
|
||||
for (const auto& result : top_results) {
|
||||
const float confidence = result.first;
|
||||
const int index = result.second;
|
||||
LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void display_usage() {
|
||||
LOG(INFO) << "label_image\n"
|
||||
<< "--accelerated, -a: [0|1], use Android NNAPI or note\n"
|
||||
<< "--count, -c: loop interpreter->Invoke() for certain times\n"
|
||||
<< "--input_floating, -f: [0|1] type of input layer is floating "
|
||||
"point numbers\n"
|
||||
<< "--input_mean, -b: input mean\n"
|
||||
<< "--input_std, -s: input standard deviation\n"
|
||||
<< "--image, -i: image_name.bmp\n"
|
||||
<< "--labels, -l: labels for the model\n"
|
||||
<< "--tflite_mode, -m: model_name.tflite\n"
|
||||
<< "--threads, -t: number of threads\n"
|
||||
<< "--verbose, -v: [0|1] print more information\n"
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
int Main(int argc, char** argv) {
|
||||
Settings s;
|
||||
|
||||
int c;
|
||||
while (1) {
|
||||
static struct option long_options[] = {
|
||||
{"accelerated", required_argument, 0, 'a'},
|
||||
{"count", required_argument, 0, 'c'},
|
||||
{"input_floating", required_argument, 0, 'f'},
|
||||
{"verbose", required_argument, 0, 'v'},
|
||||
{"image", required_argument, 0, 'i'},
|
||||
{"labels", required_argument, 0, 'l'},
|
||||
{"tflite_model", required_argument, 0, 'm'},
|
||||
{"threads", required_argument, 0, 't'},
|
||||
{"input_mean", required_argument, 0, 'b'},
|
||||
{"input_std", required_argument, 0, 's'},
|
||||
{0, 0, 0, 0}};
|
||||
|
||||
/* getopt_long stores the option index here. */
|
||||
int option_index = 0;
|
||||
|
||||
c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options,
|
||||
&option_index);
|
||||
|
||||
/* Detect the end of the options. */
|
||||
if (c == -1) break;
|
||||
|
||||
switch (c) {
|
||||
case 'a':
|
||||
s.accel = strtol( // NOLINT(runtime/deprecated_fn)
|
||||
optarg, (char**)NULL, 10);
|
||||
break;
|
||||
case 'b':
|
||||
s.input_mean = strtod(optarg, NULL);
|
||||
break;
|
||||
case 'c':
|
||||
s.loop_count = strtol( // NOLINT(runtime/deprecated_fn)
|
||||
optarg, (char**)NULL, 10);
|
||||
break;
|
||||
case 'f':
|
||||
s.input_floating = strtol( // NOLINT(runtime/deprecated_fn)
|
||||
optarg, (char**)NULL, 10);
|
||||
s.input_layer_type = "float";
|
||||
break;
|
||||
case 'i':
|
||||
s.input_bmp_name = optarg;
|
||||
break;
|
||||
case 'l':
|
||||
s.labels_file_name = optarg;
|
||||
break;
|
||||
case 'm':
|
||||
s.model_name = optarg;
|
||||
break;
|
||||
case 's':
|
||||
s.input_std = strtod(optarg, NULL);
|
||||
break;
|
||||
case 't':
|
||||
s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn)
|
||||
optarg, (char**)NULL, 10);
|
||||
break;
|
||||
case 'v':
|
||||
s.verbose = strtol( // NOLINT(runtime/deprecated_fn)
|
||||
optarg, (char**)NULL, 10);
|
||||
break;
|
||||
case 'h':
|
||||
case '?':
|
||||
/* getopt_long already printed an error message. */
|
||||
display_usage();
|
||||
exit(-1);
|
||||
default:
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
RunInference(&s);
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace label_image
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
return tflite::label_image::Main(argc, argv);
|
||||
}
|
36
tensorflow/contrib/lite/examples/label_image/label_image.h
Normal file
36
tensorflow/contrib/lite/examples/label_image/label_image.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
|
||||
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
|
||||
|
||||
#include <string>
|
||||
#include "tensorflow/contrib/lite/string.h"
|
||||
|
||||
struct Settings {
|
||||
bool verbose = false;
|
||||
bool accel = false;
|
||||
bool input_floating = false;
|
||||
int loop_count = 1;
|
||||
float input_mean = 127.5f;
|
||||
float input_std = 127.5f;
|
||||
string model_name = "./mobilenet_quant_v1_224.tflite";
|
||||
string input_bmp_name = "./grace_hopper.bmp";
|
||||
string labels_file_name = "./labels.txt";
|
||||
string input_layer_type = "uint8_t";
|
||||
int number_of_threads = 4;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
|
74
tensorflow/contrib/lite/examples/label_image/label_image.md
Normal file
74
tensorflow/contrib/lite/examples/label_image/label_image.md
Normal file
@ -0,0 +1,74 @@
|
||||
label_image for TensorFlow Lite inspired by TensorFlow's label_image.
|
||||
|
||||
To build it for android ARMv8:
|
||||
```
|
||||
> bazel build --cxxopt=-std=c++11 \
|
||||
--crosstool_top=//external:android/crosstool \
|
||||
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
--cpu=arm64-v8a \
|
||||
//tensorflow/contrib/lite/examples/label_image:label_image
|
||||
```
|
||||
or
|
||||
```
|
||||
> bazel build --config android_arm64 --cxxopt=-std=c++11 \
|
||||
//tensorflow/contrib/lite/examples/label_image:label_image
|
||||
```
|
||||
|
||||
To build it for android arm-v7a:
|
||||
```
|
||||
> bazel build --cxxopt=-std=c++11 \
|
||||
--crosstool_top=//external:android/crosstool \
|
||||
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
--cpu=armeabi-v7a \
|
||||
//tensorflow/contrib/lite/examples/label_image:label_image
|
||||
```
|
||||
or
|
||||
```
|
||||
> bazel build --config android_arm --cxxopt=-std=c++11 \
|
||||
//tensorflow/contrib/lite/examples/label_image:label_image
|
||||
```
|
||||
|
||||
Build it for desktop machines (tested on Ubuntu and OS X)
|
||||
```
|
||||
> bazel build --config opt --cxxopt=-std=c++11 //tensorflow/contrib/lite/examples/label_image:label_image
|
||||
```
|
||||
To run it. Prepare `./mobilenet_quant_v1_224.tflite`, `./grace_hopper.bmp`, and `./labels.txt`.
|
||||
|
||||
Run it:
|
||||
```
|
||||
> ./label_image
|
||||
Loaded model ./mobilenet_quant_v1_224.tflite
|
||||
resolved reporter
|
||||
invoked
|
||||
average time: 100.986 ms
|
||||
0.439216: 653 military uniform
|
||||
0.372549: 458 bow tie
|
||||
0.0705882: 466 bulletproof vest
|
||||
0.0235294: 514 cornet
|
||||
0.0196078: 835 suit
|
||||
```
|
||||
Run `interpreter->Invoker()` 100 times:
|
||||
```
|
||||
> ./label_image -c 100
|
||||
Loaded model ./mobilenet_quant_v1_224.tflite
|
||||
resolved reporter
|
||||
invoked
|
||||
average time: 33.4694 ms
|
||||
...
|
||||
```
|
||||
|
||||
Run a floating point (`mobilenet_v1_1.0_224.tflite`) model,
|
||||
```
|
||||
> ./label_image -f 1 -m mobilenet_v1_1.0_224.tflite
|
||||
Loaded model mobilenet_v1_1.0_224.tflite
|
||||
resolved reporter
|
||||
invoked
|
||||
average time: 263.493 ms
|
||||
0.88615: 653 military uniform
|
||||
0.0422316: 440 bearskin
|
||||
0.0109948: 466 bulletproof vest
|
||||
0.0105327: 401 academic gown
|
||||
0.00947104: 723 ping-pong bal
|
||||
```
|
||||
|
||||
See the source code for other command line options.
|
@ -0,0 +1,61 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h"
|
||||
#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h"
|
||||
#include "tensorflow/contrib/lite/examples/label_image/label_image.h"
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
namespace tflite {
|
||||
namespace label_image {
|
||||
|
||||
TEST(LabelImageTest, GraceHopper) {
|
||||
std::string lena_file =
|
||||
"tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp";
|
||||
int height, width, channels;
|
||||
Settings s;
|
||||
uint8_t *data;
|
||||
|
||||
data = read_bmp(lena_file, &width, &height, &channels, &s);
|
||||
ASSERT_EQ(height, 606);
|
||||
ASSERT_EQ(width, 517);
|
||||
ASSERT_EQ(channels, 3);
|
||||
|
||||
uint8_t *out = new uint8_t[606 * 517 * 3];
|
||||
downsize<uint8_t>(out, data, 606, 517, 3, 214, 214, 3, &s);
|
||||
ASSERT_EQ(out[0], 0x15);
|
||||
ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12);
|
||||
}
|
||||
|
||||
TEST(LabelImageTest, GetTopN) {
|
||||
uint8_t in[] = {1, 1, 2, 2, 4, 4, 16, 32, 128, 64};
|
||||
|
||||
std::vector<std::pair<float, int>> top_results;
|
||||
get_top_n<uint8_t>(in, 10, 5, 0.025, &top_results, false);
|
||||
ASSERT_EQ(top_results.size(), 4);
|
||||
ASSERT_EQ(top_results[0].second, 8);
|
||||
}
|
||||
|
||||
} // namespace label_image
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
BIN
tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp
vendored
Normal file
BIN
tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 919 KiB |
@ -1774,7 +1774,7 @@ inline int ANeuralNetworksExecution_setInput(
|
||||
* model. If the type is the same as specified when the model
|
||||
* was built, NULL can be passed.
|
||||
* @param memory The memory containing the data.
|
||||
* @param offset This specifies the location of the data whithin the memory.
|
||||
* @param offset This specifies the location of the data within the memory.
|
||||
* The offset is in bytes from the start of memory.
|
||||
* @param length The size in bytes of the data value.
|
||||
*
|
||||
@ -1841,7 +1841,7 @@ inline int ANeuralNetworksExecution_setOutput(
|
||||
* model. If the type is the same as specified when the model
|
||||
* was built, NULL can be passed.
|
||||
* @param memory The memory where the data is to be stored.
|
||||
* @param offset This specifies the location of the data whithin the memory.
|
||||
* @param offset This specifies the location of the data within the memory.
|
||||
* The offset is in bytes from the start of memory.
|
||||
* @param length The length in bytes of the data value.
|
||||
*
|
||||
|
@ -374,12 +374,72 @@ $(MARCH_OPTION) \
|
||||
ifdef ENABLE_EXPERIMENTAL_HEXNN_OPS
|
||||
CXXFLAGS += -DENABLE_EXPERIMENTAL_HEXNN_OPS
|
||||
endif
|
||||
|
||||
OBJDIR := $(OBJDIR)android_$(ANDROID_ARCH)/
|
||||
LIBDIR := $(LIBDIR)android_$(ANDROID_ARCH)/
|
||||
BINDIR := $(BINDIR)android_$(ANDROID_ARCH)/
|
||||
DEPDIR := $(DEPDIR)android_$(ANDROID_ARCH)/
|
||||
|
||||
ifeq ($(BUILD_FOR_TEGRA),1)
|
||||
NVCC := $(JETPACK)/cuda/bin/nvcc
|
||||
NVCCFLAGS := -x=cu -D__CUDACC__ -DNVCC -DNVIDIA_TEGRA -ccbin $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-g++ --std c++11 --expt-relaxed-constexpr -m64 -gencode arch=compute_53,\"code=sm_53\" -gencode arch=compute_62,\"code=sm_62\" -DEIGEN_AVOID_STL_ARRAY -DTENSORFLOW_USE_EIGEN_THREADPOOL -DLANG_CXX11 -DEIGEN_HAS_C99_MATH -DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=5.3
|
||||
CXXFLAGS4NVCC =\
|
||||
-DIS_SLIM_BUILD \
|
||||
-DNVIDIA_TEGRA \
|
||||
-fno-exceptions \
|
||||
-DNDEBUG $(OPTFLAGS) \
|
||||
-march=armv8-a \
|
||||
-fPIE \
|
||||
-D__ANDROID_TYPES_FULL__ \
|
||||
--sysroot $(NDK_ROOT)/platforms/android-21/arch-arm64
|
||||
|
||||
CXXFLAGS +=\
|
||||
-DGOOGLE_CUDA=1 \
|
||||
-D__ANDROID_TYPES_FULL__ \
|
||||
-DNVIDIA_TEGRA \
|
||||
-DEIGEN_AVOID_STL_ARRAY \
|
||||
-DEIGEN_HAS_C99_MATH \
|
||||
-DLANG_CXX11 -DTENSORFLOW_USE_EIGEN_THREADPOOL -DTF_EXTRA_CUDA_CAPABILITIES=5.3
|
||||
|
||||
INCLUDES += \
|
||||
-Itensorflow/core/kernels \
|
||||
-I$(MAKEFILE_DIR)/downloads/cub \
|
||||
-I$(MAKEFILE_DIR)/downloads/cub/cub_archive/cub/device \
|
||||
-Ithird_party/toolchains/gpus/cuda \
|
||||
-I$(JETPACK)/cuda/include \
|
||||
-I$(JETPACK) \
|
||||
-I$(JETPACK)/cuDNN/aarch64 \
|
||||
-I$(JETPACK)/cuda/extras/CUPTI/include
|
||||
|
||||
|
||||
LIBS += \
|
||||
-ltfcuda \
|
||||
-lcudart_static \
|
||||
-lcudnn \
|
||||
-lcublas_static \
|
||||
-lcufftw_static \
|
||||
-lcusolver_static \
|
||||
-lcusparse_static \
|
||||
-lcufft \
|
||||
-lcuda \
|
||||
-lculibos \
|
||||
-lcurand_static
|
||||
|
||||
OBJDIR := $(OBJDIR)Tegra/
|
||||
LIBDIR := $(LIBDIR)Tegra/
|
||||
BINDIR := $(BINDIR)Tegra/
|
||||
DEPDIR := $(DEPDIR)Tegra/
|
||||
|
||||
TEGRA_LIBS := \
|
||||
-L$(JETPACK)/cuda/targets/aarch64-linux-androideabi/lib \
|
||||
-L$(JETPACK)/cuda/targets/aarch64-linux-androideabi/lib/stubs \
|
||||
-L$(JETPACK)/cuda/targets/aarch64-linux-androideabi/lib64 \
|
||||
-L$(JETPACK)/cuda/targets/aarch64-linux-androideabi/lib64/stubs \
|
||||
-L$(JETPACK)/cuDNN/aarch64/cuda/lib64 \
|
||||
-L$(LIBDIR)
|
||||
|
||||
CUDA_LIB_DEPS := $(LIBDIR)libtfcuda.a
|
||||
else
|
||||
OBJDIR := $(OBJDIR)android_$(ANDROID_ARCH)/
|
||||
LIBDIR := $(LIBDIR)android_$(ANDROID_ARCH)/
|
||||
BINDIR := $(BINDIR)android_$(ANDROID_ARCH)/
|
||||
DEPDIR := $(DEPDIR)android_$(ANDROID_ARCH)/
|
||||
endif # ifeq ($(BUILD_FOR_TEGRA),1)
|
||||
endif # ANDROID
|
||||
# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt)
|
||||
|
||||
@ -585,6 +645,65 @@ $(wildcard tensorflow/core/common_runtime/gpu_device_factory.*) \
|
||||
$(wildcard tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.*) \
|
||||
$(wildcard tensorflow/core/grappler/inputs/file_input_yielder.*) \
|
||||
$(wildcard tensorflow/core/grappler/clusters/single_machine.*)
|
||||
|
||||
ifeq ($(BUILD_FOR_TEGRA),1)
|
||||
CORE_CC_ALL_SRCS := \
|
||||
$(wildcard tensorflow/core/*.cc) \
|
||||
$(wildcard tensorflow/core/common_runtime/*.cc) \
|
||||
$(wildcard tensorflow/core/common_runtime/gpu/*.cc) \
|
||||
$(wildcard tensorflow/core/framework/*.cc) \
|
||||
$(wildcard tensorflow/core/graph/*.cc) \
|
||||
$(wildcard tensorflow/core/platform/*.cc) \
|
||||
$(wildcard tensorflow/core/platform/*/*.cc) \
|
||||
$(wildcard tensorflow/core/platform/*/*/*.cc) \
|
||||
$(wildcard tensorflow/core/util/*.cc) \
|
||||
$(wildcard tensorflow/core/util/*/*.cc) \
|
||||
$(wildcard tensorflow/cc/training/*.cc) \
|
||||
$(wildcard tensorflow/stream_executor/*.cc) \
|
||||
$(wildcard tensorflow/stream_executor/*/*.cc) \
|
||||
$(wildcard tensorflow/core/grappler/optimizers/*.cc) \
|
||||
$(wildcard tensorflow/core/grappler/*.cc) \
|
||||
$(wildcard tensorflow/core/grappler/costs/*.cc) \
|
||||
$(wildcard tensorflow/core/grappler/clusters/*.cc) \
|
||||
$(wildcard tensorflow/core/grappler/utils/*.cc) \
|
||||
$(wildcard tensorflow/core/lib/core/*.cc) \
|
||||
$(wildcard tensorflow/core/lib/*/*.cc) \
|
||||
tensorflow/core/grappler/inputs/utils.cc \
|
||||
tensorflow/core/kernels/concat_lib_gpu.cc \
|
||||
tensorflow/core/kernels/cuda_solvers.cc \
|
||||
tensorflow/core/kernels/cudnn_pooling_gpu.cc \
|
||||
tensorflow/core/kernels/dense_update_functor.cc \
|
||||
tensorflow/core/kernels/fractional_avg_pool_op.cc \
|
||||
tensorflow/core/kernels/fractional_max_pool_op.cc \
|
||||
tensorflow/core/kernels/fractional_pool_common.cc \
|
||||
tensorflow/core/kernels/pooling_ops_3d.cc \
|
||||
tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
|
||||
|
||||
CORE_CC_EXCLUDE_SRCS := \
|
||||
$(wildcard tensorflow/core/*/*test.cc) \
|
||||
$(wildcard tensorflow/core/*/*testutil*) \
|
||||
$(wildcard tensorflow/core/*/*testlib*) \
|
||||
$(wildcard tensorflow/core/*/*/*test.cc) \
|
||||
$(wildcard tensorflow/core/*/*/*testutil*) \
|
||||
$(wildcard tensorflow/core/framework/op_gen_lib.cc) \
|
||||
$(wildcard tensorflow/core/lib/gif/*) \
|
||||
$(wildcard tensorflow/core/lib/jpeg/*) \
|
||||
$(wildcard tensorflow/core/lib/png/*) \
|
||||
$(wildcard tensorflow/core/lib/db/*) \
|
||||
$(wildcard tensorflow/core/platform/jpeg.*) \
|
||||
$(wildcard tensorflow/core/platform/png.*) \
|
||||
$(wildcard tensorflow/core/platform/cloud/*) \
|
||||
$(wildcard tensorflow/core/platform/s3/*) \
|
||||
$(wildcard tensorflow/core/platform/windows/*) \
|
||||
$(wildcard tensorflow/core/*/*/*testlib*) \
|
||||
$(wildcard tensorflow/cc/training/*test.cc) \
|
||||
tensorflow/core/lib/io/record_reader.cc \
|
||||
tensorflow/core/util/cuda_kernel_helper_test.cu.cc
|
||||
|
||||
CUDA_CC_SRCS := $(wildcard tensorflow/core/kernels/*.cu.cc)
|
||||
CUDA_CC_OBJS := $(addprefix $(OBJDIR), $(CUDA_CC_SRCS:.cc=.o))
|
||||
endif # TEGRA
|
||||
|
||||
# Filter out all the excluded files.
|
||||
TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
|
||||
# Add in any extra files that don't fit the patterns easily
|
||||
@ -637,11 +756,23 @@ $(LIB_PATH): $(LIB_OBJS)
|
||||
@mkdir -p $(dir $@)
|
||||
$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
|
||||
|
||||
$(BENCHMARK_NAME): $(BENCHMARK_OBJS) $(LIB_PATH)
|
||||
$(BENCHMARK_NAME): $(BENCHMARK_OBJS) $(LIB_PATH) $(CUDA_LIB_DEPS)
|
||||
@mkdir -p $(dir $@)
|
||||
$(CXX) $(CXXFLAGS) $(INCLUDES) \
|
||||
-o $(BENCHMARK_NAME) $(BENCHMARK_OBJS) \
|
||||
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
|
||||
$(LIBFLAGS) $(TEGRA_LIBS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
|
||||
|
||||
# NVCC compilation rules for Tegra
|
||||
ifeq ($(BUILD_FOR_TEGRA),1)
|
||||
$(OBJDIR)%.cu.o: %.cu.cc
|
||||
@mkdir -p $(dir $@)
|
||||
@mkdir -p $(dir $(DEPDIR)$*)
|
||||
$(NVCC) $(NVCCFLAGS) -Xcompiler "$(CXXFLAGS4NVCC) $(DEPFLAGS)" $(INCLUDES) -c $< -o $@
|
||||
|
||||
$(LIBDIR)libtfcuda.a: $(CUDA_CC_OBJS)
|
||||
@mkdir -p $(dir $@)
|
||||
$(AR) $(ARFLAGS) $@ $(CUDA_CC_OBJS)
|
||||
endif
|
||||
|
||||
# Matches on the normal hand-written TensorFlow C++ source files.
|
||||
$(OBJDIR)%.o: %.cc | $(PBT_GEN_FILES)
|
||||
@ -730,6 +861,7 @@ clean_except_protobuf_libs:
|
||||
cleantarget:
|
||||
rm -rf $(OBJDIR)
|
||||
rm -rf $(BINDIR)
|
||||
rm -rf $(LIBDIR)
|
||||
|
||||
$(DEPDIR)/%.d: ;
|
||||
.PRECIOUS: $(DEPDIR)/%.d
|
||||
|
@ -26,7 +26,7 @@ usage() {
|
||||
echo "-x [hexagon library path] copy and hexagon libraries in the specified path"
|
||||
echo "-a [architecture] Architecture of target android [default=armeabi-v7a] \
|
||||
(supported architecture list: \
|
||||
arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64)"
|
||||
arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64 tegra)"
|
||||
exit 1
|
||||
}
|
||||
|
||||
@ -50,6 +50,26 @@ while getopts "Es:t:Tx:a:" opt_name; do
|
||||
done
|
||||
shift $((OPTIND - 1))
|
||||
|
||||
if [ "$ARCH" == "tegra" ]; then
|
||||
if [[ -z "${JETPACK}" ]]; then
|
||||
export JETPACK="$HOME/JetPack_Android_3.0"
|
||||
fi
|
||||
if [ ! -d ${JETPACK} ]; then
|
||||
echo "Can't find Jetpack at ${JETPACK}"
|
||||
echo "Set JETPACK=<path to Jetpack Android> to specify a non-default Jetpack path"
|
||||
exit -1
|
||||
fi
|
||||
if [ ! -d ${JETPACK}/cuda ]; then
|
||||
ln -s $(ls -d ${JETPACK}/cuda-*/|sort -r|head -n1) ${JETPACK}/cuda
|
||||
fi
|
||||
if [ ! -d ${JETPACK}/cuda ]; then
|
||||
ln -s $(ls -d ${JETPACK}/cuda-*/|sort -r|head -n1) ${JETPACK}/cuda
|
||||
fi
|
||||
|
||||
export BUILD_FOR_TEGRA=1
|
||||
ARCH="arm64-v8a"
|
||||
fi
|
||||
|
||||
# Make sure we're in the correct directory, at the root of the source tree.
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)"
|
||||
cd "${SCRIPT_DIR}"/../../../
|
||||
|
@ -34,6 +34,7 @@ PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.
|
||||
RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
|
||||
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)"
|
||||
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
|
||||
CUB_URL="$(grep -o 'https.*cub/archive.*zip' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
|
||||
|
||||
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
|
||||
# so work around it by patching the source.
|
||||
@ -82,6 +83,7 @@ download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf"
|
||||
download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2"
|
||||
download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d"
|
||||
download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl"
|
||||
download_and_extract "${CUB_URL}" "${DOWNLOADS_DIR}/cub/external/cub_archive"
|
||||
|
||||
replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \
|
||||
"${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
|
||||
|
@ -48,10 +48,10 @@ INFERENCE_OBJS := $(addprefix $(OBJDIR), $(INFERENCE_SRCS:.cc=.o))
|
||||
INFERENCE_SO_NAME := libtensorflow_inference.so
|
||||
INFERENCE_SO_PATH := $(LIBDIR)$(INFERENCE_SO_NAME)
|
||||
|
||||
$(INFERENCE_SO_PATH): $(LIB_OBJS) $(INFERENCE_OBJS)
|
||||
$(INFERENCE_SO_PATH): $(LIB_OBJS) $(INFERENCE_OBJS) $(CUDA_LIB_DEPS)
|
||||
@mkdir -p $(dir $@)
|
||||
$(CXX) $(CXXFLAGS) $(INCLUDES) \
|
||||
-o $@ $(INFERENCE_OBJS) $(LIB_OBJS) \
|
||||
-o $@ $(INFERENCE_OBJS) $(LIB_OBJS) $(TEGRA_LIBS) \
|
||||
$(LIBFLAGS) $(LDFLAGS) \
|
||||
-shared -Wl,-soname,$(INFERENCE_SO_NAME) \
|
||||
$(LIBS)
|
||||
|
@ -44,7 +44,7 @@ py_library(
|
||||
"naming.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:private"],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/py2tf/convert",
|
||||
"//tensorflow/contrib/py2tf/pyct",
|
||||
|
@ -9,7 +9,7 @@ END
|
||||
in_arg {
|
||||
name: "axis"
|
||||
description: <<END
|
||||
A `Tensor` of type `int64` (default: 0). The axis of the Tensor to
|
||||
A `Tensor` of type `int32` (default: None). The axis of the Tensor to
|
||||
find the unique elements.
|
||||
END
|
||||
}
|
||||
@ -26,12 +26,15 @@ A 1-D Tensor. Has the same type as x that contains the index of each
|
||||
value of x in the output y.
|
||||
END
|
||||
}
|
||||
summary: "Finds unique elements in a 1-D tensor."
|
||||
summary: "Finds unique elements along an axis of a tensor."
|
||||
description: <<END
|
||||
This operation returns a tensor `y` containing all of the unique elements of `x`
|
||||
sorted in the same order that they occur in `x`. This operation also returns a
|
||||
tensor `idx` the same size as `x` that contains the index of each value of `x`
|
||||
in the unique output `y`. In other words:
|
||||
This operation either returns a tensor `y` containing unique elements
|
||||
along the `axis` of a tensor. The returned unique elements is sorted
|
||||
in the same order as they occur along `axis` in `x`.
|
||||
This operation also returns a tensor `idx` that is the same size as
|
||||
the number of the elements in `x` along the `axis` dimension. It
|
||||
contains the index in the unique output `y`.
|
||||
In other words, for an `1-D` tensor `x` with `axis = None:
|
||||
|
||||
`y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
|
||||
|
||||
@ -43,5 +46,30 @@ y, idx = unique(x)
|
||||
y ==> [1, 2, 4, 7, 8]
|
||||
idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
|
||||
```
|
||||
|
||||
For an `2-D` tensor `x` with `axis = 0`:
|
||||
|
||||
```
|
||||
# tensor 'x' is [[1, 0, 0],
|
||||
# [1, 0, 0],
|
||||
# [2, 0, 0]]
|
||||
y, idx = unique(x, axis=0)
|
||||
y ==> [[1, 0, 0],
|
||||
[2, 0, 0]]
|
||||
idx ==> [0, 0, 1]
|
||||
```
|
||||
|
||||
For an `2-D` tensor `x` with `axis = 1`:
|
||||
|
||||
```
|
||||
# tensor 'x' is [[1, 0, 0],
|
||||
# [1, 0, 0],
|
||||
# [2, 0, 0]]
|
||||
y, idx = unique(x, axis=1)
|
||||
y ==> [[1, 0],
|
||||
[1, 0],
|
||||
[2, 0]]
|
||||
idx ==> [0, 1, 1]
|
||||
```
|
||||
END
|
||||
}
|
||||
|
4
tensorflow/core/api_def/python_api/api_def_Unique.pbtxt
Normal file
4
tensorflow/core/api_def/python_api/api_def_Unique.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Unique"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "UniqueV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -149,33 +149,25 @@ FunctionDef XTimes16() {
|
||||
{{"y", "y:y:0"}});
|
||||
}
|
||||
|
||||
FunctionDef WXPlusB(){return FDH::Define(
|
||||
// Name
|
||||
"WXPlusB",
|
||||
// Args
|
||||
{"w: T", "x: T", "b: T"},
|
||||
// Return values
|
||||
{"y: T"},
|
||||
// Attr def
|
||||
{"T: {float, double}"},
|
||||
// Nodes
|
||||
{
|
||||
{{"mm"},
|
||||
"MatMul",
|
||||
{"w", "x"},
|
||||
{
|
||||
{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false},
|
||||
#ifdef INTEL_MKL
|
||||
}},
|
||||
#else
|
||||
FunctionDef WXPlusB() {
|
||||
return FDH::Define(
|
||||
// Name
|
||||
"WXPlusB",
|
||||
// Args
|
||||
{"w: T", "x: T", "b: T"},
|
||||
// Return values
|
||||
{"y: T"},
|
||||
// Attr def
|
||||
{"T: {float, double}"},
|
||||
// Nodes
|
||||
{{{"mm"},
|
||||
"MatMul",
|
||||
{"w", "x"},
|
||||
{{"T", "$T"},
|
||||
{"transpose_a", false},
|
||||
{"transpose_b", false},
|
||||
{"_kernel", "eigen"}}},
|
||||
#endif
|
||||
{
|
||||
{"y"}, "Add", {"mm", "b"}, {
|
||||
{ "T", "$T" }
|
||||
}
|
||||
}
|
||||
});
|
||||
{{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
|
||||
}
|
||||
|
||||
FunctionDef Swap() {
|
||||
|
@ -76,7 +76,7 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs(
|
||||
} // namespace numext
|
||||
} // namespace Eigen
|
||||
|
||||
#ifdef COMPILER_MSVC
|
||||
#if defined(COMPILER_MSVC) && !defined(__clang__)
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<Eigen::half> {
|
||||
|
@ -52,7 +52,8 @@ limitations under the License.
|
||||
#undef REGISTER_PARTITION
|
||||
*/
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)
|
||||
#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \
|
||||
defined(NVIDIA_TEGRA)
|
||||
|
||||
// All types are supported, so all macros are invoked.
|
||||
//
|
||||
|
@ -38,6 +38,7 @@ load(
|
||||
"tf_mkl_kernel_library",
|
||||
"cc_header_only_library",
|
||||
"if_not_windows",
|
||||
"if_override_eigen_strong_inline",
|
||||
)
|
||||
load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
@ -3069,6 +3070,10 @@ tf_kernel_library(
|
||||
":xsmm": ["xsmm_conv2d.h"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
# Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
|
||||
# So that it doesn't take 20 minutes to compile conv_grad_ops_3d.cc and conv_ops_3d.cc
|
||||
# on Windows. See https://github.com/tensorflow/tensorflow/issues/10521
|
||||
copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
|
||||
defines = select({
|
||||
":xsmm": [
|
||||
"TENSORFLOW_USE_LIBXSMM",
|
||||
|
@ -512,7 +512,7 @@ struct GreaterThan {
|
||||
constexpr bool operator()(int a, int b) const { return a > b; }
|
||||
};
|
||||
|
||||
// For each data type, the tile size posibility frontier denotes the tile size
|
||||
// For each data type, the tile size possibility frontier denotes the tile size
|
||||
// combinations that consume the most computational resources constrained by
|
||||
// - number of threads per SM limit,
|
||||
// - limit on size of the short dimension (<=15) due to the definition of
|
||||
|
@ -535,13 +535,15 @@ struct MatMulFunctor<SYCLDevice, T> {
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
|
||||
#define REGISTER_CPU_EIGEN(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
|
||||
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>)
|
||||
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
|
||||
REGISTER_CPU_EIGEN(T);
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
@ -556,9 +558,14 @@ struct MatMulFunctor<SYCLDevice, T> {
|
||||
#if defined(INTEL_MKL)
|
||||
// MKL does not support half and int32 types for matrix-multiplication, so
|
||||
// register the kernel to use default Eigen based implementations for these
|
||||
// types
|
||||
// types. Registration for NO-LABEL version is in mkl_matmul_op.cc
|
||||
TF_CALL_float(REGISTER_CPU_EIGEN);
|
||||
TF_CALL_double(REGISTER_CPU_EIGEN);
|
||||
TF_CALL_half(REGISTER_CPU);
|
||||
|
||||
TF_CALL_int32(REGISTER_CPU);
|
||||
TF_CALL_complex64(REGISTER_CPU_EIGEN);
|
||||
TF_CALL_complex128(REGISTER_CPU_EIGEN);
|
||||
#else
|
||||
TF_CALL_float(REGISTER_CPU);
|
||||
TF_CALL_double(REGISTER_CPU);
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
|
||||
@ -63,8 +64,17 @@ class UniqueOp : public OpKernel {
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
|
||||
errors::InvalidArgument("unique expects a 1D vector."));
|
||||
} else {
|
||||
auto axis_vec = axis_tensor.vec<int64>();
|
||||
axis = axis_vec(0);
|
||||
OP_REQUIRES(context,
|
||||
(axis_tensor.dtype() == DT_INT32 ||
|
||||
axis_tensor.dtype() == DT_INT64),
|
||||
errors::InvalidArgument(
|
||||
"axis tensor should be int32 or int64, but got ",
|
||||
axis_tensor.dtype()));
|
||||
if (axis_tensor.dtype() == DT_INT32) {
|
||||
axis = internal::SubtleMustCopy(axis_tensor.scalar<int32>()());
|
||||
} else {
|
||||
axis = internal::SubtleMustCopy(axis_tensor.scalar<int64>()());
|
||||
}
|
||||
axis = axis < 0 ? axis + input.dims() : axis;
|
||||
OP_REQUIRES(context, 0 <= axis && axis < input.dims(),
|
||||
errors::InvalidArgument("axis has to be between [0, ",
|
||||
|
@ -1165,10 +1165,11 @@ REGISTER_OP("Unique")
|
||||
|
||||
REGISTER_OP("UniqueV2")
|
||||
.Input("x: T")
|
||||
.Input("axis: int64")
|
||||
.Input("axis: Taxis")
|
||||
.Output("y: T")
|
||||
.Output("idx: out_idx")
|
||||
.Attr("T: type")
|
||||
.Attr("Taxis: {int32,int64} = DT_INT64")
|
||||
.Attr("out_idx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
|
||||
|
@ -112,8 +112,8 @@ accelerator_micros and cpu_micros. Note: cpu and accelerator can run in parallel
|
||||
|
||||
`-account_displayed_op_only`: If True, only account the statistics of ops eventually displayed. If False, account all op statistics matching -account_type_regexes recursively.
|
||||
|
||||
|
||||
Notes: See <b>overview</b> sesion on how does above options play with each other to decide the output and counting.
|
||||
Notes: See <b>overview</b> session on how does above options play with each
|
||||
other to decide the output and counting.
|
||||
|
||||
`-select`: Comma-separated list of attributes to show. Supported attributes:
|
||||
[bytes|peak_bytes|residual_bytes|output_bytes|micros|accelerator_micros|cpu_micros|params|float_ops|occurrence|tensor_value|device|op_types|input_shapes].
|
||||
|
@ -291,7 +291,7 @@ calls @{tf.decode_csv} to parse a single line into its features
|
||||
and the label. Since Estimators require that features be represented as a
|
||||
dictionary, we rely on Python's built-in `dict` and `zip` functions to build
|
||||
that dictionary. The feature names are the keys of that dictionary.
|
||||
We then then call the dictionary's `pop` method to remove the label field from
|
||||
We then call the dictionary's `pop` method to remove the label field from
|
||||
the features dictionary:
|
||||
|
||||
``` python
|
||||
|
@ -120,7 +120,7 @@ data set.
|
||||
text patterns.
|
||||
|
||||
Further useful articles are
|
||||
[How to Use t-SNE Effectively](distill.pub/2016/misread-tsne/) and
|
||||
[How to Use t-SNE Effectively](https://distill.pub/2016/misread-tsne/) and
|
||||
[Principal Component Analysis Explained Visually](http://setosa.io/ev/principal-component-analysis/).
|
||||
|
||||
### Exploration
|
||||
|
@ -479,10 +479,10 @@ does not specify one.
|
||||
### Serving the exported model locally
|
||||
|
||||
For local deployment, you can serve your model using
|
||||
[TensorFlow Serving](http://github.com/tensorflow/serving), an open-source project that loads a
|
||||
SavedModel and exposes it as a [gRPC](http://www.grpc.io/) service.
|
||||
[TensorFlow Serving](https://github.com/tensorflow/serving), an open-source project that loads a
|
||||
SavedModel and exposes it as a [gRPC](https://www.grpc.io/) service.
|
||||
|
||||
First, [install TensorFlow Serving](http://github.com/tensorflow/serving).
|
||||
First, [install TensorFlow Serving](https://github.com/tensorflow/serving).
|
||||
|
||||
Then build and run the local model server, substituting `$export_dir_base` with
|
||||
the path to the SavedModel you exported above:
|
||||
|
@ -51,6 +51,16 @@ tf_cc_binary(
|
||||
}),
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "label_image_py",
|
||||
srcs = ["label_image.py"],
|
||||
main = "label_image.py",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -73,10 +73,23 @@ Python than the Python code mentioned in the
|
||||
[Inception tutorial](https://www.tensorflow.org/tutorials/image_recognition/).
|
||||
and could be easier to add visualization or debug code.
|
||||
|
||||
With tensorflow python package installed, you can run it like:
|
||||
|
||||
`bazel-bin/tensorflow/examples/label_image/label_image_py` should be there after
|
||||
```bash
|
||||
$ bazel build tensorflow/examples/label_image/...
|
||||
```
|
||||
|
||||
Run
|
||||
|
||||
```bash
|
||||
$ bazel-bin/tensorflow/examples/label_image/label_image_py
|
||||
```
|
||||
|
||||
Or, with tensorflow python package installed, you can run it like:
|
||||
```bash
|
||||
$ python3 tensorflow/examples/label_image/label_image.py
|
||||
```
|
||||
|
||||
And get result similar to this:
|
||||
```
|
||||
military uniform 0.834305
|
||||
|
@ -14,6 +14,7 @@ load(
|
||||
"tf_copts",
|
||||
"tf_custom_op_library",
|
||||
"tf_java_test",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
java_library(
|
||||
@ -113,10 +114,12 @@ cc_library(
|
||||
name = "java_op_gen_lib",
|
||||
srcs = [
|
||||
"src/gen/cc/op_generator.cc",
|
||||
"src/gen/cc/source_writer.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"src/gen/cc/java_defs.h",
|
||||
"src/gen/cc/op_generator.h",
|
||||
"src/gen/cc/source_writer.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
@ -305,6 +308,20 @@ filegroup(
|
||||
]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "source_writer_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"src/gen/cc/source_writer_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":java_op_gen_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "libtensorflow_jni",
|
||||
srcs = select({
|
||||
|
62
tensorflow/java/src/gen/cc/source_writer.cc
Normal file
62
tensorflow/java/src/gen/cc/source_writer.cc
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/java/src/gen/cc/source_writer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
SourceWriter& SourceWriter::Append(const StringPiece& str) {
|
||||
if (!str.empty()) {
|
||||
if (newline_) {
|
||||
DoAppend(left_margin_ + line_prefix_);
|
||||
newline_ = false;
|
||||
}
|
||||
DoAppend(str);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
SourceWriter& SourceWriter::Write(const string& str) {
|
||||
size_t line_pos = 0;
|
||||
do {
|
||||
size_t start_pos = line_pos;
|
||||
line_pos = str.find('\n', start_pos);
|
||||
if (line_pos != string::npos) {
|
||||
++line_pos;
|
||||
Append(StringPiece(str.data() + start_pos, line_pos - start_pos));
|
||||
newline_ = true;
|
||||
} else {
|
||||
Append(StringPiece(str.data() + start_pos, str.size() - start_pos));
|
||||
}
|
||||
} while (line_pos != string::npos && line_pos < str.size());
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
SourceWriter& SourceWriter::EndLine() {
|
||||
Append("\n");
|
||||
newline_ = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
SourceWriter& SourceWriter::Indent(int tab) {
|
||||
left_margin_.resize(std::max(static_cast<int>(left_margin_.size() + tab), 0),
|
||||
' ');
|
||||
return *this;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
133
tensorflow/java/src/gen/cc/source_writer.h
Normal file
133
tensorflow/java/src/gen/cc/source_writer.h
Normal file
@ -0,0 +1,133 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_SOURCE_WRITER_H_
|
||||
#define TENSORFLOW_JAVA_SRC_GEN_CC_SOURCE_WRITER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A utility class for writing source code, normally generated at
|
||||
// compile-time.
|
||||
//
|
||||
// Source writers are language-agnostic and therefore only expose generic
|
||||
// methods common to most languages. Extend or wrap this class to implement
|
||||
// language-specific features.
|
||||
//
|
||||
// Note: if you are looking to reuse this class for generating code in another
|
||||
// language than Java, please do by moving it at the '//tensorflow/core/lib/io'
|
||||
// level.
|
||||
class SourceWriter {
|
||||
public:
|
||||
virtual ~SourceWriter() = default;
|
||||
|
||||
// Returns true if the writer is at the beginnig of a new line
|
||||
bool newline() const { return newline_; }
|
||||
|
||||
// Appends a piece of code or text.
|
||||
//
|
||||
// It is expected that no newline character is present in the data provided,
|
||||
// otherwise Write() must be used.
|
||||
SourceWriter& Append(const StringPiece& str);
|
||||
|
||||
// Writes a block of code or text.
|
||||
//
|
||||
// The data might potentially contain newline characters, therefore it will
|
||||
// be scanned to ensure that each line is indented and prefixed properly,
|
||||
// making it a bit slower than Append().
|
||||
SourceWriter& Write(const string& text);
|
||||
|
||||
// Appends a newline character and start writing on a new line.
|
||||
SourceWriter& EndLine();
|
||||
|
||||
// Indents following lines with white spaces.
|
||||
//
|
||||
// Indentation is cumulative, i.e. the provided tabulation is added to the
|
||||
// current indentation value. If the tabulation is negative, the operation
|
||||
// will outdent the source code, until the indentation reaches 0 again.
|
||||
//
|
||||
// For example, calling Indent(2) twice will indent code with 4 white
|
||||
// spaces. Then calling Indent(-2) will outdent the code back to 2 white
|
||||
// spaces.
|
||||
SourceWriter& Indent(int tab);
|
||||
|
||||
// Prefixes following lines with provided character(s).
|
||||
//
|
||||
// A common use case of a prefix is for commenting or documenting the code.
|
||||
//
|
||||
// The prefix is written after the indentation, For example, invoking
|
||||
// Indent(2)->Prefix("//") will result in prefixing lines with " //".
|
||||
//
|
||||
// An empty value ("") will remove any line prefix that was previously set.
|
||||
SourceWriter& Prefix(const char* line_prefix) {
|
||||
line_prefix_ = line_prefix;
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void DoAppend(const StringPiece& str) = 0;
|
||||
|
||||
private:
|
||||
string left_margin_;
|
||||
string line_prefix_;
|
||||
bool newline_ = true;
|
||||
};
|
||||
|
||||
// A writer that outputs source code into a file.
|
||||
//
|
||||
// Note: the writer does not acquire the ownership of the file being passed in
|
||||
// parameter.
|
||||
class SourceFileWriter : public SourceWriter {
|
||||
public:
|
||||
explicit SourceFileWriter(WritableFile* file) : file_(file) {}
|
||||
virtual ~SourceFileWriter() = default;
|
||||
|
||||
protected:
|
||||
void DoAppend(const StringPiece& str) override {
|
||||
TF_CHECK_OK(file_->Append(str));
|
||||
}
|
||||
|
||||
private:
|
||||
WritableFile* file_;
|
||||
};
|
||||
|
||||
// A writer that outputs source code into a string buffer.
|
||||
class SourceBufferWriter : public SourceWriter {
|
||||
public:
|
||||
SourceBufferWriter() : owns_buffer_(true), buffer_(new string()) {}
|
||||
explicit SourceBufferWriter(string* buffer)
|
||||
: owns_buffer_(false), buffer_(buffer) {}
|
||||
virtual ~SourceBufferWriter() {
|
||||
if (owns_buffer_) delete buffer_;
|
||||
}
|
||||
const string& str() { return *buffer_; }
|
||||
|
||||
protected:
|
||||
void DoAppend(const StringPiece& str) override {
|
||||
buffer_->append(str.begin(), str.end());
|
||||
}
|
||||
|
||||
private:
|
||||
bool owns_buffer_;
|
||||
string* buffer_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_SOURCE_WRITER_H_
|
215
tensorflow/java/src/gen/cc/source_writer_test.cc
Normal file
215
tensorflow/java/src/gen/cc/source_writer_test.cc
Normal file
@ -0,0 +1,215 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/java/src/gen/cc/source_writer.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(AppendTest, SingleLineText) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye and I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye and I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(AppendTest, MultiLineText) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye\nand I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye\nand I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(AppendTest, MultiLineTextWithIndent) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Indent(2).Append("You say goodbye\nand I say hello!");
|
||||
|
||||
const char* expected = " You say goodbye\nand I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(AppendTest, MultiLineTextWithPrefix) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Prefix("--").Append("You say goodbye\nand I say hello!");
|
||||
|
||||
const char* expected = "--You say goodbye\nand I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(AppendTest, MultiLineTextWithIndentAndPrefix) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Indent(2).Prefix("--").Append("You say goodbye\nand I say hello!");
|
||||
|
||||
const char* expected = " --You say goodbye\nand I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(WriteTest, SingleLineText) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Write("You say goodbye and I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye and I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(WriteTest, MultiLineText) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Write("You say goodbye\nand I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye\nand I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(WriteTest, MultiLineTextWithIndent) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Indent(2).Write("You say goodbye\nand I say hello!");
|
||||
|
||||
const char* expected = " You say goodbye\n and I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(WriteTest, MultiLineTextWithPrefix) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Prefix("--").Write("You say goodbye\nand I say hello!");
|
||||
|
||||
const char* expected = "--You say goodbye\n--and I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(WriteTest, MultiLineTextWithIndentAndPrefix) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Indent(2).Prefix("--").Write("You say goodbye\nand I say hello!");
|
||||
|
||||
const char* expected = " --You say goodbye\n --and I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, Basic) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye").EndLine().Append("and I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye\nand I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, Indent) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye")
|
||||
.EndLine()
|
||||
.Indent(2)
|
||||
.Append("and I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye\n and I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, IndentAndOutdent) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye")
|
||||
.EndLine()
|
||||
.Indent(2)
|
||||
.Append("and I say hello!")
|
||||
.EndLine()
|
||||
.Indent(-2)
|
||||
.Append("Hello, hello!");
|
||||
|
||||
const char* expected = "You say goodbye\n and I say hello!\nHello, hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, Prefix) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye")
|
||||
.EndLine()
|
||||
.Prefix("--")
|
||||
.Append("and I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye\n--and I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, PrefixAndRemovePrefix) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye")
|
||||
.EndLine()
|
||||
.Prefix("--")
|
||||
.Append("and I say hello!")
|
||||
.EndLine()
|
||||
.Prefix("")
|
||||
.Append("Hello, hello!");
|
||||
|
||||
const char* expected = "You say goodbye\n--and I say hello!\nHello, hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, IndentAndPrefixAndOutdentAndRemovePrefix) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye")
|
||||
.EndLine()
|
||||
.Indent(2)
|
||||
.Prefix("--")
|
||||
.Append("and I say hello!")
|
||||
.EndLine()
|
||||
.Indent(-2)
|
||||
.Prefix("")
|
||||
.Append("Hello, hello!");
|
||||
|
||||
const char* expected = "You say goodbye\n --and I say hello!\nHello, hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, NegativeIndent) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye")
|
||||
.EndLine()
|
||||
.Indent(-10)
|
||||
.Append("and I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye\nand I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, CumulativeIndent) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye")
|
||||
.EndLine()
|
||||
.Indent(2)
|
||||
.Append("and I say hello!")
|
||||
.EndLine()
|
||||
.Indent(2)
|
||||
.Append("Hello, hello!");
|
||||
|
||||
const char* expected =
|
||||
"You say goodbye\n and I say hello!\n Hello, hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
TEST(MarginTest, EmptyPrefix) {
|
||||
SourceBufferWriter writer;
|
||||
writer.Append("You say goodbye")
|
||||
.EndLine()
|
||||
.Prefix("")
|
||||
.Append("and I say hello!");
|
||||
|
||||
const char* expected = "You say goodbye\nand I say hello!";
|
||||
ASSERT_STREQ(expected, writer.str().data());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -67,6 +67,7 @@ genrule(
|
||||
genrule(
|
||||
name = "copy_jni_md_h",
|
||||
srcs = select({
|
||||
"//tensorflow:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"],
|
||||
"//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
|
||||
"//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
|
||||
}),
|
||||
|
@ -186,6 +186,7 @@ tf_py_test(
|
||||
":client_testlib",
|
||||
":platform",
|
||||
],
|
||||
tags = ["no_windows"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
|
@ -2679,7 +2679,7 @@ class OutputTypesTest(test_util.TensorFlowTestCase):
|
||||
with g.as_default():
|
||||
x = constant_op.constant([1, 1, 2, 4, 4, 4, 7, 8, 8],
|
||||
dtype=dtypes.double)
|
||||
y, _ = gen_array_ops.unique(x)
|
||||
y, _ = gen_array_ops._unique(x)
|
||||
self.assertEqual([types_pb2.DT_DOUBLE, types_pb2.DT_INT32],
|
||||
y.op._output_types) # pylint: disable=protected-access
|
||||
|
||||
|
@ -441,6 +441,24 @@ class CreluTest(test.TestCase):
|
||||
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
|
||||
use_gpu=True)
|
||||
|
||||
def testNumbersWithAxis0(self):
|
||||
with self.test_session():
|
||||
crelu = nn_ops.crelu(
|
||||
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0)
|
||||
tf_relu = crelu.eval()
|
||||
np_crelu = np.array([[0, 7, 0, 3, 0], [1, 0, 5, 0, 9], [9, 0, 5, 0, 1],
|
||||
[0, 3, 0, 7, 0]])
|
||||
self.assertAllEqual(np_crelu, tf_relu)
|
||||
|
||||
def testNumbersWithAxis1(self):
|
||||
with self.test_session():
|
||||
crelu = nn_ops.crelu(
|
||||
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1)
|
||||
tf_relu = crelu.eval()
|
||||
np_crelu = np.array([[0, 7, 0, 3, 0, 9, 0, 5, 0, 1],
|
||||
[1, 0, 5, 0, 9, 0, 3, 0, 7, 0]])
|
||||
self.assertAllEqual(np_crelu, tf_relu)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -63,23 +63,24 @@ class UniqueTest(test.TestCase):
|
||||
self.assertEqual(x[i], tf_y[tf_idx[i]].decode('ascii'))
|
||||
|
||||
def testInt32Axis(self):
|
||||
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
|
||||
with self.test_session() as sess:
|
||||
y0, idx0 = gen_array_ops.unique_v2(x, axis=[0])
|
||||
tf_y0, tf_idx0 = sess.run([y0, idx0])
|
||||
y1, idx1 = gen_array_ops.unique_v2(x, axis=[1])
|
||||
tf_y1, tf_idx1 = sess.run([y1, idx1])
|
||||
self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
|
||||
self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
|
||||
self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
|
||||
self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
|
||||
for dtype in [np.int32, np.int64]:
|
||||
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
|
||||
with self.test_session() as sess:
|
||||
y0, idx0 = gen_array_ops._unique_v2(x, axis=np.array([0], dtype))
|
||||
tf_y0, tf_idx0 = sess.run([y0, idx0])
|
||||
y1, idx1 = gen_array_ops._unique_v2(x, axis=np.array([1], dtype))
|
||||
tf_y1, tf_idx1 = sess.run([y1, idx1])
|
||||
self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
|
||||
self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
|
||||
self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
|
||||
self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
|
||||
|
||||
def testInt32V2(self):
|
||||
# This test is only temporary, once V2 is used
|
||||
# by default, the axis will be wrapped to allow `axis=None`.
|
||||
x = np.random.randint(2, high=10, size=7000)
|
||||
with self.test_session() as sess:
|
||||
y, idx = gen_array_ops.unique_v2(x, axis=[])
|
||||
y, idx = gen_array_ops._unique_v2(x, axis=np.array([], np.int32))
|
||||
tf_y, tf_idx = sess.run([y, idx])
|
||||
|
||||
self.assertEqual(len(x), len(tf_idx))
|
||||
|
@ -1279,6 +1279,17 @@ def sparse_mask(a, mask_indices, name=None):
|
||||
return ops.IndexedSlices(out_values, out_indices, a.dense_shape)
|
||||
|
||||
|
||||
def unique(x, out_idx=dtypes.int32, name=None):
|
||||
# TODO(yongtang): switch to v2 once API deprecation
|
||||
# period (3 weeks) pass.
|
||||
# TODO(yongtang): The documentation should also
|
||||
# be updated when switch to v2.
|
||||
return gen_array_ops._unique(x, out_idx, name)
|
||||
|
||||
|
||||
unique.__doc__ = gen_array_ops._unique.__doc__
|
||||
|
||||
|
||||
def split(value, num_or_size_splits, axis=0, num=None, name="split"):
|
||||
"""Splits a tensor into sub tensors.
|
||||
|
||||
@ -2032,7 +2043,7 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
|
||||
hypothesis = tf.SparseTensor(
|
||||
[[0, 0, 0],
|
||||
[1, 0, 0]],
|
||||
["a", "b"]
|
||||
["a", "b"],
|
||||
(2, 1, 1))
|
||||
|
||||
# 'truth' is a tensor of shape `[2, 2]` with variable-length values:
|
||||
@ -2044,7 +2055,7 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
|
||||
[[0, 1, 0],
|
||||
[1, 0, 0],
|
||||
[1, 0, 1],
|
||||
[1, 1, 0]]
|
||||
[1, 1, 0]],
|
||||
["a", "b", "c", "a"],
|
||||
(2, 2, 2))
|
||||
|
||||
|
@ -704,8 +704,8 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
|
||||
def testWarnings(self):
|
||||
# TODO(gunan) Reenable after this issue is fixed:
|
||||
# https://github.com/google/protobuf/issues/2812
|
||||
if sys.version_info >= (3, 6):
|
||||
self.skipTest("Skipped test for Python 3.6+")
|
||||
if sys.version_info >= (3, 5):
|
||||
self.skipTest("Skipped test for Python 3.5+")
|
||||
|
||||
# Smaller than the threshold: no warning.
|
||||
c_sparse = ops.IndexedSlices(
|
||||
|
@ -30,6 +30,8 @@ Squeeze
|
||||
Slice
|
||||
TileGrad # Exported through array_grad instead of array_ops.
|
||||
ZerosLike # TODO(josh11b): Use this instead of the Python version.
|
||||
Unique
|
||||
UniqueV2
|
||||
Unpack
|
||||
|
||||
# candidate_sampling_ops
|
||||
|
@ -219,7 +219,7 @@ def _SparseSegmentSqrtNGrad(op, grad):
|
||||
|
||||
@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
|
||||
def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
|
||||
"""Gradient for SparseSegmentSqrtNWithNumSegmnets."""
|
||||
"""Gradient for SparseSegmentSqrtNWithNumSegments."""
|
||||
dim0 = array_ops.shape(op.inputs[0])[0]
|
||||
return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
|
||||
dim0), None, None, None)
|
||||
|
@ -1498,7 +1498,7 @@ def bias_add_v1(value, bias, name=None):
|
||||
return gen_nn_ops._bias_add_v1(value, bias, name=name)
|
||||
|
||||
|
||||
def crelu(features, name=None):
|
||||
def crelu(features, name=None, axis=-1):
|
||||
"""Computes Concatenated ReLU.
|
||||
|
||||
Concatenates a ReLU which selects only the positive part of the activation
|
||||
@ -1510,13 +1510,14 @@ def crelu(features, name=None):
|
||||
features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
|
||||
`int16`, or `int8`.
|
||||
name: A name for the operation (optional).
|
||||
axis: The axis that the output values are concatenated along. Default is -1.
|
||||
|
||||
Returns:
|
||||
A `Tensor` with the same type as `features`.
|
||||
"""
|
||||
with ops.name_scope(name, "CRelu", [features]) as name:
|
||||
features = ops.convert_to_tensor(features, name="features")
|
||||
c = array_ops.concat([features, -features], -1, name=name)
|
||||
c = array_ops.concat([features, -features], axis, name=name)
|
||||
return gen_nn_ops.relu(c)
|
||||
|
||||
|
||||
|
@ -54,7 +54,7 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
|
||||
kernel_class = pywrap_tensorflow.TryFindKernelClass(
|
||||
node_def.SerializeToString())
|
||||
if kernel_class:
|
||||
op_and_kernel = (str(node_def.op), kernel_class.decode('utf-8'))
|
||||
op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8')))
|
||||
if op_and_kernel not in ops:
|
||||
ops.add(op_and_kernel)
|
||||
else:
|
||||
|
@ -232,7 +232,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
|
||||
result = StringToDriverVersion(version);
|
||||
}
|
||||
#else
|
||||
#if !defined(PLATFORM_WINDOWS)
|
||||
#if !defined(PLATFORM_WINDOWS) && !defined(NVIDIA_TEGRA)
|
||||
// Callback used when iterating through DSOs. Looks for the driver-interfacing
|
||||
// DSO and yields its version number into the callback data, when found.
|
||||
auto iterate_phdr =
|
||||
|
@ -340,8 +340,8 @@ class KernelArgIterator {
|
||||
//
|
||||
// This class exists as a way to pass kernel arguments to
|
||||
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
|
||||
// be templated to accept any KernelArgsArray type, therfore a reference to this
|
||||
// base type is passed instead.
|
||||
// be templated to accept any KernelArgsArray type, therefore a reference to
|
||||
// this base type is passed instead.
|
||||
//
|
||||
// Performance is not a concern here because each of these methods will be
|
||||
// called at most once per kernel launch. Past performance concerns with
|
||||
|
@ -150,6 +150,12 @@ def if_darwin(a):
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def if_override_eigen_strong_inline(a):
|
||||
return select({
|
||||
clean_dep("//tensorflow:override_eigen_strong_inline"): a,
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def get_win_copts(is_external=False):
|
||||
WINDOWS_COPTS = [
|
||||
"/D__VERSION__=\\\"MSVC\\\"",
|
||||
@ -196,7 +202,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False):
|
||||
+ if_linux_x86_64(["-msse3"])
|
||||
+ if_ios_x86_64(["-msse4.1"])
|
||||
+ select({
|
||||
"//tensorflow:framework_shared_object": [],
|
||||
clean_dep("//tensorflow:framework_shared_object"): [],
|
||||
"//conditions:default": ["-DTENSORFLOW_MONOLITHIC_BUILD"],
|
||||
})
|
||||
+ select({
|
||||
@ -922,7 +928,8 @@ def tf_kernel_library(name,
|
||||
if not deps:
|
||||
deps = []
|
||||
if not copts:
|
||||
copts = tf_copts(is_external=is_external)
|
||||
copts = []
|
||||
copts = copts + tf_copts(is_external=is_external)
|
||||
if prefix:
|
||||
if native.glob([prefix + "*.cu.cc"], exclude=["*test*"]):
|
||||
if not gpu_srcs:
|
||||
|
@ -86,7 +86,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "crelu"
|
||||
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'features\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ctc_beam_search_decoder"
|
||||
|
@ -102,9 +102,6 @@ function run_configure_for_cpu_build {
|
||||
if [ -z "$TF_ENABLE_XLA" ]; then
|
||||
export TF_ENABLE_XLA=0
|
||||
fi
|
||||
if [ -z "$CC_OPT_FLAGS" ]; then
|
||||
export CC_OPT_FLAGS="-march=native"
|
||||
fi
|
||||
if [ -z "$TF_NEED_MKL" ]; then
|
||||
export TF_NEED_MKL=0
|
||||
fi
|
||||
@ -120,17 +117,14 @@ function run_configure_for_gpu_build {
|
||||
# yes "" | ./configure doesn't work on Windows, so we set all the
|
||||
# environment variables in advance to avoid interact with the script.
|
||||
export TF_NEED_CUDA=1
|
||||
export TF_CUDA_VERSION=8.0
|
||||
export CUDA_TOOLKIT_PATH="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v8.0"
|
||||
export TF_CUDNN_VERSION=6.0
|
||||
export TF_CUDA_VERSION=9.0
|
||||
export CUDA_TOOLKIT_PATH="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0"
|
||||
export TF_CUDNN_VERSION=7.0
|
||||
export CUDNN_INSTALL_PATH="C:/tools/cuda"
|
||||
export TF_CUDA_COMPUTE_CAPABILITIES="3.7"
|
||||
if [ -z "$TF_ENABLE_XLA" ]; then
|
||||
export TF_ENABLE_XLA=0
|
||||
fi
|
||||
if [ -z "$CC_OPT_FLAGS" ]; then
|
||||
export CC_OPT_FLAGS="-march=native"
|
||||
fi
|
||||
export TF_NEED_VERBS=0
|
||||
export TF_NEED_MKL=0
|
||||
export TF_NEED_GCP=0
|
||||
|
@ -46,6 +46,6 @@ export PATH="/c/${PYTHON_BASE_PATH}:$PATH"
|
||||
export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH"
|
||||
|
||||
# Add Cuda and Cudnn dll directories into PATH
|
||||
export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v8.0/bin:$PATH"
|
||||
export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v8.0/extras/CUPTI/libx64:$PATH"
|
||||
export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/bin:$PATH"
|
||||
export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/extras/CUPTI/libx64:$PATH"
|
||||
export PATH="/c/tools/cuda/bin:$PATH"
|
||||
|
@ -44,7 +44,10 @@ source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \
|
||||
|
||||
run_configure_for_cpu_build
|
||||
|
||||
bazel build -c opt tensorflow/tools/pip_package:build_pip_package || exit $?
|
||||
# --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
|
||||
# by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521
|
||||
BUILD_OPTS="--define=override_eigen_strong_inline=true"
|
||||
bazel build -c opt $BUILD_OPTS tensorflow/tools/pip_package:build_pip_package || exit $?
|
||||
|
||||
# Create a python test directory to avoid package name conflict
|
||||
PY_TEST_DIR="py_test_dir"
|
||||
@ -58,7 +61,7 @@ reinstall_tensorflow_pip ${PIP_NAME}
|
||||
|
||||
# Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
|
||||
# which will result testing system installed tensorflow
|
||||
bazel test -c opt -k --test_output=errors \
|
||||
bazel test -c opt $BUILD_OPTS -k --test_output=errors \
|
||||
--define=no_tensorflow_py_deps=true --test_lang_filters=py \
|
||||
--test_tag_filters=-no_pip,-no_windows,-no_oss \
|
||||
--build_tag_filters=-no_pip,-no_windows,-no_oss --build_tests_only \
|
||||
|
@ -0,0 +1 @@
|
||||
c:\tools\msys64\usr\bin\bash -l %cd%/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh %*
|
@ -31,14 +31,6 @@ if [ ! -e "WORKSPACE" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Enable JNI support for Windows in Bazel.
|
||||
# This can be removed once
|
||||
# https://github.com/bazelbuild/bazel/pull/2599
|
||||
# has been merged and we switch to a bazel release containing it.
|
||||
cp "${JAVA_HOME}/include/win32/jni_md.h" "./tensorflow/java/src/main/native/windows_jni_md.h"
|
||||
sed -i -e "s|@bazel_tools//tools/jdk:jni_md_header-linux|windows_jni_md.h|" ./tensorflow/java/src/main/native/BUILD
|
||||
#### END HACKS TO BE RESOLVED WITH NEW BAZEL VERSIONS ####
|
||||
|
||||
export TF_BAZEL_TARGETS="//tensorflow:libtensorflow.so"
|
||||
export TF_BAZEL_TARGETS="${TF_BAZEL_TARGETS} //tensorflow/tools/lib_package:clicenses_generate"
|
||||
export TF_BAZEL_TARGETS="${TF_BAZEL_TARGETS} //tensorflow/java:libtensorflow_jni.so"
|
||||
@ -55,11 +47,6 @@ bazel build -c opt \
|
||||
tensorflow/java:libtensorflow_jni.so \
|
||||
tensorflow/tools/lib_package:jnilicenses_generate
|
||||
|
||||
# Revert the hacks above
|
||||
git checkout ./tensorflow/tools/pip_package/BUILD
|
||||
git checkout ./tensorflow/java/src/main/native/BUILD
|
||||
rm -f ./tensorflow/java/src/main/native/windows_jni_md.h
|
||||
|
||||
DIR=lib_package
|
||||
rm -rf ${DIR}
|
||||
mkdir -p ${DIR}
|
||||
@ -80,7 +67,7 @@ cp tensorflow/c/c_api.h ${DIR}/include/tensorflow/c
|
||||
cp tensorflow/c/eager/c_api.h ${DIR}/include/tensorflow/c/eager
|
||||
cp bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE ${DIR}/include/tensorflow/c
|
||||
cd ${DIR}
|
||||
zip -j libtensorflow-cpu-windows-$(uname -m).zip \
|
||||
zip libtensorflow-cpu-windows-$(uname -m).zip \
|
||||
lib/tensorflow.dll \
|
||||
include/tensorflow/c/eager/c_api.h \
|
||||
include/tensorflow/c/c_api.h \
|
||||
|
72
tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh
Normal file
72
tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh
Normal file
@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
#
|
||||
# Script to produce binary release of libtensorflow (C API, Java jars etc.).
|
||||
|
||||
set -ex
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
# Setup environment for bazel builds
|
||||
source "${SCRIPT_DIR}/bazel/common_env.sh"
|
||||
source "${SCRIPT_DIR}/bazel/bazel_test_lib.sh"
|
||||
|
||||
# Sanity check that this is being run from the root of the git repository.
|
||||
cd ${SCRIPT_DIR}/../../../..
|
||||
if [ ! -e "WORKSPACE" ]; then
|
||||
echo "Must run this from the root of the bazel workspace"
|
||||
echo "Currently at ${PWD}, script is at ${SCRIPT_DIR}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export TF_BAZEL_TARGETS="//tensorflow:libtensorflow.so"
|
||||
export TF_BAZEL_TARGETS="${TF_BAZEL_TARGETS} //tensorflow/tools/lib_package:clicenses_generate"
|
||||
export TF_BAZEL_TARGETS="${TF_BAZEL_TARGETS} //tensorflow/java:libtensorflow_jni.so"
|
||||
export TF_BAZEL_TARGETS="${TF_BAZEL_TARGETS} //tensorflow/tools/lib_package:jnilicenses_generate"
|
||||
|
||||
run_configure_for_gpu_build
|
||||
|
||||
# build_libtensorflow_tarball in ../builds/libtensorflow.sh
|
||||
# cannot be used on Windows since it relies on pkg_tar rules.
|
||||
# So we do something special here
|
||||
bazel build -c opt \
|
||||
tensorflow:libtensorflow.so \
|
||||
tensorflow/tools/lib_package:clicenses_generate \
|
||||
tensorflow/java:libtensorflow_jni.so \
|
||||
tensorflow/tools/lib_package:jnilicenses_generate
|
||||
|
||||
DIR=lib_package
|
||||
rm -rf ${DIR}
|
||||
mkdir -p ${DIR}
|
||||
|
||||
# Zip up the .dll and the LICENSE for the JNI library.
|
||||
cp bazel-bin/tensorflow/java/libtensorflow_jni.so ${DIR}/tensorflow_jni.dll
|
||||
zip -j ${DIR}/libtensorflow_jni-gpu-windows-$(uname -m).zip \
|
||||
${DIR}/tensorflow_jni.dll \
|
||||
bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/jni/LICENSE
|
||||
rm -f ${DIR}/tensorflow_jni.dll
|
||||
|
||||
# Zip up the .dll, LICENSE and include files for the C library.
|
||||
mkdir -p ${DIR}/include/tensorflow/c
|
||||
mkdir -p ${DIR}/lib
|
||||
cp bazel-bin/tensorflow/libtensorflow.so ${DIR}/lib/tensorflow.dll
|
||||
cp tensorflow/c/c_api.h ${DIR}/include/tensorflow/c
|
||||
cp bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE ${DIR}/include/tensorflow/c
|
||||
cd ${DIR}
|
||||
zip -j libtensorflow-gpu-windows-$(uname -m).zip \
|
||||
lib/tensorflow.dll \
|
||||
include/tensorflow/c/c_api.h \
|
||||
include/tensorflow/c/LICENSE
|
||||
rm -rf lib include
|
@ -10,6 +10,7 @@ load(
|
||||
"transitive_hdrs",
|
||||
)
|
||||
load("//third_party/mkl:build_defs.bzl", "if_mkl")
|
||||
load("//tensorflow:tensorflow.bzl", "if_cuda")
|
||||
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
|
||||
|
||||
# This returns a list of headers of all public header libraries (e.g.,
|
||||
@ -34,7 +35,9 @@ transitive_hdrs(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
] + if_cuda([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
]),
|
||||
)
|
||||
|
||||
py_binary(
|
||||
@ -169,6 +172,7 @@ sh_binary(
|
||||
"//tensorflow/contrib/ndlstm:ndlstm",
|
||||
"//tensorflow/contrib/nn:nn_py",
|
||||
"//tensorflow/contrib/predictor:predictor_pip",
|
||||
"//tensorflow/contrib/py2tf:py2tf_internal",
|
||||
"//tensorflow/contrib/py2tf/convert:convert",
|
||||
"//tensorflow/contrib/py2tf/pyct:pyct",
|
||||
"//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis",
|
||||
|
@ -27,6 +27,8 @@ function cp_external() {
|
||||
for f in `find "$src_dir" -maxdepth 1 -mindepth 1 ! -name '*local_config_cuda*' ! -name '*org_tensorflow*'`; do
|
||||
cp -R "$f" "$dest_dir"
|
||||
done
|
||||
mkdir -p "${dest_dir}/local_config_cuda/cuda/cuda/"
|
||||
cp "${src_dir}/local_config_cuda/cuda/cuda/cuda_config.h" "${dest_dir}/local_config_cuda/cuda/cuda/"
|
||||
}
|
||||
|
||||
PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
|
||||
|
1
third_party/astor.BUILD
vendored
1
third_party/astor.BUILD
vendored
@ -19,5 +19,6 @@ py_library(
|
||||
"astor/string_repr.py",
|
||||
"astor/tree_walk.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
1
third_party/gast.BUILD
vendored
1
third_party/gast.BUILD
vendored
@ -14,5 +14,6 @@ py_library(
|
||||
"gast/astn.py",
|
||||
"gast/gast.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
1
third_party/termcolor.BUILD
vendored
1
third_party/termcolor.BUILD
vendored
@ -10,5 +10,6 @@ py_library(
|
||||
srcs = [
|
||||
"termcolor.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user