Merge changes from github.
PiperOrigin-RevId: 192850372

parent ef24ad1450
commit 3652556dab
@@ -450,11 +450,12 @@ tf_cc_shared_object(
    linkstatic = 1,
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework_internal_impl",
        "//tensorflow/core:lib_internal_impl",
        "//tensorflow/core:core_cpu_impl",
        "//tensorflow/stream_executor:stream_executor_impl",
        "//tensorflow/core:framework_internal_impl",
        "//tensorflow/core:gpu_runtime_impl",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
        "//tensorflow/core:lib_internal_impl",
        "//tensorflow/stream_executor:stream_executor_impl",
    ] + tf_additional_binary_deps(),
)

@@ -318,6 +318,7 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/kernels:bounds_check",
    ],
)

@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/version.h"

@@ -441,6 +442,9 @@ string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src,
  }

  auto node_name = [&cycles, &graph](int node_id) {
    if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
      return string("(null)");
    }
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
tensorflow/contrib/cmake/external/grpc.cmake (vendored)
@@ -35,6 +35,7 @@ else()
  set(grpc_STATIC_LIBRARIES
      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a
      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a
      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a
      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a
      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a)
endif()

@@ -201,7 +201,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
    #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it
    #stores String-based info such as name, device and type of the op.
    #Unique to every Operation instance.
    new_node_def = deepcopy(op._node_def)
    new_node_def = deepcopy(op.node_def)
    #Change the name
    new_node_def.name = new_name

@@ -211,7 +211,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''):

    #Make a copy of the op_def too.
    #Its unique to every _type_ of Operation.
    op_def = deepcopy(op._op_def)
    op_def = deepcopy(op.op_def)

    #Initialize a new Operation instance
    new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types,

@@ -25,6 +25,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@Counter
@@SqlDataset

@@assert_element_shape
@@batch_and_drop_remainder
@@bucket_by_sequence_length
@@dense_to_sparse_batch
@@ -55,6 +56,7 @@ from __future__ import print_function

# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.batching import assert_element_shape
from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
from tensorflow.contrib.data.python.ops.batching import map_and_batch

@@ -21,6 +21,7 @@ py_test(
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:script_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:string_ops",
        "//tensorflow/python:tensor_shape",
@@ -28,8 +28,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test

@@ -579,5 +581,73 @@ class PaddedBatchDatasetSerializationTest(
        lambda: build_dataset(seq_lens2), 8)


class RestructuredDatasetTest(test.TestCase):

  def test_assert_element_shape(self):

    def create_unknown_shape_dataset(x):
      return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
                                           np.zeros((3, 4), dtype=np.int32)),
                                [x],
                                [dtypes.float32, dtypes.int32])

    dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
    unknown_shapes = (tensor_shape.TensorShape(None),
                      tensor_shape.TensorShape(None))
    self.assertEqual(unknown_shapes, dataset.output_shapes)

    expected_shapes = (tensor_shape.TensorShape(2),
                       tensor_shape.TensorShape((3, 4)))
    result = dataset.apply(batching.assert_element_shape(expected_shapes))
    self.assertEqual(expected_shapes, result.output_shapes)

    iterator = result.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()
    with self.test_session() as sess:
      sess.run(init_op)
      for _ in range(5):
        sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def test_assert_wrong_element_shape(self):

    def create_dataset(_):
      return (array_ops.ones(2, dtype=dtypes.float32),
              array_ops.zeros((3, 4), dtype=dtypes.int32))

    dataset = dataset_ops.Dataset.range(3).map(create_dataset)
    wrong_shapes = (tensor_shape.TensorShape(2),
                    tensor_shape.TensorShape((3, 10)))
    with self.assertRaises(ValueError):
      dataset.apply(batching.assert_element_shape(wrong_shapes))

  def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):

    def create_unknown_shape_dataset(x):
      return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
                                           np.zeros((3, 4), dtype=np.int32)),
                                [x],
                                [dtypes.float32, dtypes.int32])

    dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
    unknown_shapes = (tensor_shape.TensorShape(None),
                      tensor_shape.TensorShape(None))
    self.assertEqual(unknown_shapes, dataset.output_shapes)

    wrong_shapes = (tensor_shape.TensorShape(2),
                    tensor_shape.TensorShape((3, 10)))
    iterator = (
        dataset.apply(batching.assert_element_shape(wrong_shapes))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()
    with self.test_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)


if __name__ == "__main__":
  test.main()

@@ -47,6 +47,11 @@ class SequenceDatasetSerializationTest(
    # Skip nothing
    self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)

  def testInvalidSkip(self):
    with self.assertRaisesRegexp(ValueError,
                                 'Shape must be rank 0 but is rank 1'):
      self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)

  def _build_take_dataset(self, count):
    components = (np.arange(10),)
    return dataset_ops.Dataset.from_tensor_slices(components).take(count)
@@ -69,6 +74,11 @@ class SequenceDatasetSerializationTest(
    # Take nothing
    self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)

  def testInvalidTake(self):
    with self.assertRaisesRegexp(ValueError,
                                 'Shape must be rank 0 but is rank 1'):
      self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)

  def _build_repeat_dataset(self, count, take_count=3):
    components = (np.arange(10),)
    return dataset_ops.Dataset.from_tensor_slices(components).take(

@@ -112,6 +112,7 @@ py_library(
    srcs = ["batching.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:dtypes",

@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework import with_shape
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
@@ -345,6 +346,46 @@ class _RestructuredDataset(dataset_ops.Dataset):
    return self._output_shapes


def assert_element_shape(expected_shapes):
  """Assert the shape of this `Dataset`.

  ```python
  shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
  result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
  print(result.output_shapes)  # ==> "((16, 256), <unknown>)"
  ```

  If dataset shapes and expected_shapes are fully defined, assert they match.
  Otherwise, add an assert op that will validate the shapes when tensors are
  evaluated, and set shapes on tensors, respectively.

  Args:
    expected_shapes: A nested structure of `tf.TensorShape` objects.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}
  """

  def _check_shape(*elements):
    flatten_tensors = nest.flatten(elements)
    flatten_shapes = nest.flatten(expected_shapes)
    checked_tensors = [
        with_shape(shape, tensor)
        for shape, tensor in zip(flatten_shapes, flatten_tensors)
    ]
    return nest.pack_sequence_as(elements, checked_tensors)

  def _apply_fn(dataset):
    return _RestructuredDataset(
        dataset.map(_check_shape),
        dataset.output_types,
        output_shapes=expected_shapes,
        output_classes=dataset.output_classes)

  return _apply_fn


class _MapAndBatchDataset(dataset_ops.MapDataset):
  """A `Dataset` that maps a function over a batch of elements."""

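To see the new transformation end to end, here is a minimal sketch in TF 1.x style (the single-component dataset and shapes are hypothetical, chosen only for illustration): shapes that `py_func` leaves unknown get both a static shape annotation and a runtime check.

```python
import numpy as np
import tensorflow as tf

# A dataset whose element shape is unknown to the graph (py_func hides it).
dataset = tf.data.Dataset.range(5).map(
    lambda x: tf.py_func(lambda _: np.ones((3, 4), np.float32), [x], tf.float32))
print(dataset.output_shapes)  # <unknown>

# assert_element_shape sets the static shape and adds a runtime assert.
dataset = dataset.apply(
    tf.contrib.data.assert_element_shape(tf.TensorShape([3, 4])))
print(dataset.output_shapes)  # (3, 4)
```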
@@ -73,7 +73,7 @@ class DistributedValues(object):

  @property
  def devices(self):
    return self._index.keys()
    return list(self._index.keys())

  def __str__(self):
    return "%s:%s" % (self.__class__.__name__, self._index)

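The `list(...)` wrapper matters under Python 3, where `dict.keys()` returns a view object rather than a list; a view cannot be indexed. A quick illustration:

```python
d = {"/gpu:0": 1, "/gpu:1": 2}
keys = d.keys()           # Python 3: a dict_keys view, not a list
# keys[0]                 # TypeError: 'dict_keys' object is not subscriptable
devices = list(d.keys())  # an indexable, independent copy
print(devices[0])
```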
@@ -43,10 +43,10 @@ def sparse_multiclass_hinge_loss(

  This is a generalization of standard (binary) hinge loss. For a given instance
  with correct label c*, the loss is given by:
  loss = max_{c != c*} logits_c - logits_{c*} + 1.
  $$loss = max_{c != c*} logits_c - logits_{c*} + 1.$$
  or equivalently
  loss = max_c { logits_c - logits_{c*} + I_{c != c*} }
  where I_{c != c*} = 1 if c != c* and 0 otherwise.
  $$loss = max_c { logits_c - logits_{c*} + I_{c != c*} }$$
  where \\(I_{c != c*} = 1\ \text{if}\ c != c*\\) and 0 otherwise.

  Args:
    labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to

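A small worked example of the formula above, with hypothetical logits rather than anything from the library (note that the equivalent max_c form clamps the loss at 0, since c = c* contributes 0):

```python
import numpy as np

logits = np.array([2.0, 0.5, 1.0])  # hypothetical scores for classes 0..2
c_star = 0                          # correct label
# loss = max_{c != c*} logits_c - logits_{c*} + 1
loss = max(logits[c] for c in range(len(logits)) if c != c_star) \
    - logits[c_star] + 1.0
print(loss)  # 0.0: the margin over the nearest rival class is exactly 1
```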
@@ -34,33 +34,31 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper):
  r"""Class that implements Random Fourier Feature Mapping (RFFM) in TensorFlow.

  The RFFM mapping is used to approximate the Gaussian (RBF) kernel:
  ```
  exp(-||x-y||_2^2 / (2 * sigma^2))
  ```
  $$exp(-||x-y||_2^2 / (2 * \sigma^2))$$

  The implementation of RFFM is based on the following paper:
  "Random Features for Large-Scale Kernel Machines" by Ali Rahimi and Ben Recht.
  (link: https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf)

  The mapping uses a matrix `Omega \in R^{d x D}` and a bias vector `b \in R^D`
  where `d` is the input dimension (number of dense input features) and `D` is
  the output dimension (i.e., dimension of the feature space the input is mapped
  to). Each entry of `Omega` is sampled i.i.d. from a (scaled) Gaussian
  distribution and each entry of `b` is sampled independently and uniformly from
  [0, 2 * pi].
  The mapping uses a matrix \\(\Omega \in R^{d x D}\\) and a bias vector
  \\(b \in R^D\\) where \\(d\\) is the input dimension (number of dense input
  features) and \\(D\\) is the output dimension (i.e., dimension of the feature
  space the input is mapped to). Each entry of \\(\Omega\\) is sampled i.i.d.
  from a (scaled) Gaussian distribution and each entry of \\(b\\) is sampled
  independently and uniformly from [0, \\(2 * \pi\\)].

  For a single input feature vector x in R^d, its RFFM is defined as:
  ```
  sqrt(2/D) * cos(x * Omega + b)
  ```
  where `cos` is the element-wise cosine function and `x, b` are represented as
  row vectors. The aforementioned paper shows that the linear kernel of
  RFFM-mapped vectors approximates the Gaussian kernel of the initial vectors.
  For a single input feature vector \\(x \in R^d\\), its RFFM is defined as:
  $$\sqrt(2/D) * cos(x * \Omega + b)$$

  where \\(cos\\) is the element-wise cosine function and \\(x, b\\) are
  represented as row vectors. The aforementioned paper shows that the linear
  kernel of RFFM-mapped vectors approximates the Gaussian kernel of the initial
  vectors.
  """

  def __init__(self, input_dim, output_dim, stddev=1.0, seed=1, name=None):
    """Constructs a RandomFourierFeatureMapper instance.
    r"""Constructs a RandomFourierFeatureMapper instance.

    Args:
      input_dim: The dimension (number of features) of the tensors to be mapped.
@@ -68,11 +66,11 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper):
      stddev: The standard deviation of the Gaussian kernel to be approximated.
        The error of the classifier trained using this approximation is very
        sensitive to this parameter.
      seed: An integer used to initialize the parameters (`Omega` and `b`) of
        the mapper. For repeatable sequences across different invocations of the
        mapper object (for instance, to ensure consistent mapping both at
        training and eval/inference if these happen in different invocations),
        set this to the same integer.
      seed: An integer used to initialize the parameters (\\(\Omega\\) and
        \\(b\\)) of the mapper. For repeatable sequences across different
        invocations of the mapper object (for instance, to ensure consistent
        mapping both at training and eval/inference if these happen in
        different invocations), set this to the same integer.
      name: name for the mapper object.
    """
    # TODO(sibyl-vie3Poto): Maybe infer input_dim and/or output_dim (if not explicitly

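A minimal NumPy sketch of the mapping described in that docstring, assuming illustrative values for d, D and sigma (not taken from the library's defaults):

```python
import numpy as np

d, D, sigma = 4, 512, 1.0
rng = np.random.RandomState(1)
omega = rng.normal(scale=1.0 / sigma, size=(d, D))  # Omega in R^{d x D}
b = rng.uniform(0.0, 2.0 * np.pi, size=D)           # b in R^D

def rffm(x):
    # sqrt(2/D) * cos(x * Omega + b), with x as a row vector
    return np.sqrt(2.0 / D) * np.cos(x @ omega + b)

x, y = rng.normal(size=d), rng.normal(size=d)
approx = rffm(x) @ rffm(y)  # linear kernel of the mapped vectors
exact = np.exp(-np.linalg.norm(x - y) ** 2 / (2.0 * sigma ** 2))
print(approx, exact)        # close for large D
```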
@@ -34,7 +34,7 @@ def _inner_product(x, y):
  """Inner product between tensors x and y.

  The input tensors are assumed to be in ROW representation, that is, the method
  returns x * y^T.
  returns \\(x * y^T\\).

  Args:
    x: input tensor in row format

@@ -19,11 +19,11 @@ Information matrix. Suppose one has a model that parameterizes a posterior
distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its
Fisher Information matrix is given by,

F(params) = E[ v(x, y, params) v(x, y, params)^T ]
$$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$

where,

v(x, y, params) = (d / d params) log p(y | x, params)
$$v(x, y, params) = (d / d params) log p(y | x, params)$$

and the expectation is taken with respect to the data's distribution for 'x' and
the model's posterior distribution for 'y',
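As a toy illustration of these definitions (not part of the library), consider a model where the Fisher information is known in closed form and estimate it by Monte Carlo:

```python
import numpy as np

# Toy model: y ~ N(theta, 1), so the score is v(y, theta) = y - theta
# and the exact Fisher information is F(theta) = E[v^2] = 1.
rng = np.random.RandomState(0)
theta = 0.3
y = rng.normal(loc=theta, size=100000)  # samples from the model's distribution
v = y - theta                           # score, (d/dtheta) log p(y | theta)
print(np.mean(v * v))                   # ~= 1.0, matching the exact value
```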
@@ -85,7 +85,7 @@ def normalize_damping(damping, num_replications):
def compute_pi_tracenorm(left_cov, right_cov):
  """Computes the scalar constant pi for Tikhonov regularization/damping.

  pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
  $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
  See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.

  Args:
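A two-line numeric check of that formula with made-up covariance factors:

```python
import numpy as np

A = np.diag([4.0, 4.0])  # trace 8, dim 2
B = np.diag([1.0, 1.0])  # trace 2, dim 2
pi = np.sqrt((np.trace(A) / A.shape[0]) / (np.trace(B) / B.shape[0]))
print(pi)  # 2.0: damping shifts toward the factor with larger average eigenvalue
```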
@@ -462,14 +462,14 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
  into it. We are interested in Fisher(params)[i, i]. This is,

  Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
                       = E[ v(x, y, params)[i] ^ 2 ]
  $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
                         = E[ v(x, y, params)[i] ^ 2 ]$$

  Consider a fully connected layer in this model with (unshared) weight matrix
  'w'. For an example 'x' that produces layer inputs 'a' and output
  preactivations 's',

  v(x, y, w) = vec( a (d loss / d s)^T )
  $$v(x, y, w) = vec( a (d loss / d s)^T )$$

  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
  to the layer's parameters 'w'.

@@ -532,14 +532,14 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
  into it. We are interested in Fisher(params)[i, i]. This is,

  Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
                       = E[ v(x, y, params)[i] ^ 2 ]
  $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
                         = E[ v(x, y, params)[i] ^ 2 ]$$

  Consider a convolutional layer in this model with (unshared) filter matrix
  'w'. For an example image 'x' that produces layer inputs 'a' and output
  preactivations 's',

  v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )
  $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$

  where 'loc' is a single (x, y) location in an image.

@@ -805,12 +805,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
  'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
  this FisherBlock estimates,

  F(w) = #locations * kronecker(E[flat(a) flat(a)^T],
                                E[flat(ds) flat(ds)^T])
  $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
                                   E[flat(ds) flat(ds)^T])$$

  where

  ds = (d / ds) log p(y | x, w)
  $$ds = (d / ds) log p(y | x, w)$$
  #locations = number of (x, y) locations where 'w' is applied.

  where the expectation is taken over all examples and locations and flat()
@@ -1567,7 +1567,7 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,

    if self._option == SeriesFBApproximation.option1:

      # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G.
      # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\)
      L_A, psi_A = self._input_factor.get_option1quants(
          self._input_damping_func)
      L_G, psi_G = self._output_factor.get_option1quants(
@@ -1581,33 +1581,33 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
        T = self._num_timesteps
        return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))

      # Y = gamma( psi_G*psi_A^T ) (computed element-wise)
      # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
      # Even though Y is Z-independent we are recomputing it from the psi's
      # each since Y depends on both A and G quantities, and it is relatively
      # cheap to compute.
      Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)

      # Z = L_G^T * Z * L_A
      # \\(Z = L_G^T * Z * L_A\\)
      # This is equivalent to the following computation from the original
      # pseudo-code:
      # Z = G0^(-1/2) * Z * A0^(-1/2)
      # Z = U_G^T * Z * U_A
      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
      # \\(Z = U_G^T * Z * U_A\\)
      Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)

      # Z = Z .* Y
      # \\(Z = Z .* Y\\)
      Z *= Y

      # Z = L_G * Z * L_A^T
      # \\(Z = L_G * Z * L_A^T\\)
      # This is equivalent to the following computation from the original
      # pseudo-code:
      # Z = U_G * Z * U_A^T
      # Z = G0^(-1/2) * Z * A0^(-1/2)
      # \\(Z = U_G * Z * U_A^T\\)
      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
      Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))

    elif self._option == SeriesFBApproximation.option2:

      # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1),
      # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G.
      # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\),
      # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\)
      P_A, K_A, mu_A = self._input_factor.get_option2quants(
          self._input_damping_func)
      P_G, K_G, mu_G = self._output_factor.get_option2quants(
@@ -1616,26 +1616,26 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
      # Our approach differs superficially from the pseudo-code in the paper
      # in order to reduce the total number of matrix-matrix multiplies.
      # In particular, the first three computations in the pseudo code are
      # Z = G0^(-1/2) * Z * A0^(-1/2)
      # Z = Z - hPsi_G^T * Z * hPsi_A
      # Z = E_G^T * Z * E_A
      # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that
      # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2)
      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
      # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
      # \\(Z = E_G^T * Z * E_A\\)
      # Noting that \\(hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
      # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\)
      # the entire computation can be written as
      # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2)
      #     - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A
      #   = E_G^T * (G0^(-1/2) * Z * A0^(-1/2)
      #     - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A
      #   = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A
      #     - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A
      #   = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A
      # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
      # \\(    - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
      # \\(  = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
      # \\(    - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
      # \\(  = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
      # \\(    - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
      # \\(  = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\)
      # This final expression is computed by the following two lines:
      # Z = Z - P_G * Z * P_A^T
      # \\(Z = Z - P_G * Z * P_A^T\\)
      Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))
      # Z = K_G^T * Z * K_A
      # \\(Z = K_G^T * Z * K_A\\)
      Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)

      # Z = Z ./ (1*1^T - mu_G*mu_A^T)
      # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
      # Be careful with the outer product. We don't want to accidentally
      # make it an inner-product instead.
      tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
@@ -1646,13 +1646,13 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
      # We now perform the transpose/reverse version of the operations
      # derived above, whose derivation from the original pseudo-code is
      # analogous.
      # Z = K_G * Z * K_A^T
      # \\(Z = K_G * Z * K_A^T\\)
      Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))

      # Z = Z - P_G^T * Z * P_A
      # \\(Z = Z - P_G^T * Z * P_A\\)
      Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)

      # Z = normalize (1/E[T]) * Z
      # \\(Z = normalize (1/E[T]) * Z\\)
      # Note that this normalization is done because we compute the statistics
      # by averaging, not summing, over time. (And the gradient is presumably
      # summed over time, not averaged, and thus their scales are different.)

@@ -19,11 +19,16 @@ set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/../../.."

make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \
  $SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \
  $SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \
  $SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \
  $SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \
  $SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a

lipo \
tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \

@@ -63,6 +63,8 @@ def _safe_div(numerator, denominator, name):
      name=name)


@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_true_positives(predictions,
                             labels,
                             weights=None,
@@ -107,6 +109,8 @@ def streaming_true_positives(predictions,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.true_negatives. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_true_negatives(predictions,
                             labels,
                             weights=None,
@@ -151,6 +155,8 @@ def streaming_true_negatives(predictions,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.false_positives. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_false_positives(predictions,
                              labels,
                              weights=None,
@@ -195,6 +201,8 @@ def streaming_false_positives(predictions,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.false_negatives. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_false_negatives(predictions,
                              labels,
                              weights=None,
@@ -238,6 +246,7 @@ def streaming_false_negatives(predictions,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.mean')
def streaming_mean(values,
                   weights=None,
                   metrics_collections=None,
@@ -287,6 +296,7 @@ def streaming_mean(values,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.mean_tensor')
def streaming_mean_tensor(values,
                          weights=None,
                          metrics_collections=None,
@@ -340,9 +350,8 @@ def streaming_mean_tensor(values,
      name=name)


@deprecated(None,
            'Please switch to tf.metrics.accuracy. Note that the order of the '
            'labels and predictions arguments has been switched.')
@deprecated(None, 'Please switch to tf.metrics.accuracy. Note that the order '
            'of the labels and predictions arguments has been switched.')
def streaming_accuracy(predictions,
                       labels,
                       weights=None,
@@ -400,6 +409,8 @@ def streaming_accuracy(predictions,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.precision. Note that the order '
            'of the labels and predictions arguments has been switched.')
def streaming_precision(predictions,
                        labels,
                        weights=None,
@@ -456,6 +467,8 @@ def streaming_precision(predictions,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.recall. Note that the order '
            'of the labels and predictions arguments has been switched.')
def streaming_recall(predictions,
                     labels,
                     weights=None,
@@ -975,8 +988,8 @@ def streaming_curve_points(labels=None,
  return points, update_op


@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of the '
            'labels and predictions arguments has been switched.')
@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of '
            'the labels and predictions arguments has been switched.')
def streaming_auc(predictions,
                  labels,
                  weights=None,
@@ -1797,9 +1810,9 @@ def streaming_sensitivity_at_specificity(predictions,
      name=name)


@deprecated(
    None, 'Please switch to tf.metrics.precision_at_thresholds. Note that the '
    'order of the labels and predictions arguments has been switched.')
@deprecated(None,
            'Please switch to tf.metrics.precision_at_thresholds. Note that '
            'the order of the labels and predictions arguments are switched.')
def streaming_precision_at_thresholds(predictions,
                                      labels,
                                      thresholds,

@@ -2891,7 +2891,7 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):

    output_size = weight.get_shape().as_list()[1]
    g = vs.get_variable(name, [output_size], dtype=weight.dtype)
    return nn_impl.l2_normalize(weight, dim=0) * g
    return nn_impl.l2_normalize(weight, axis=0) * g

  def _linear(self,
              args,

@@ -610,8 +610,8 @@ def monotonic_attention(p_choose_i, previous_attention, mode):
  addition, once an input sequence element is attended to at a given output
  timestep, elements occurring before it cannot be attended to at subsequent
  output timesteps. This function generates attention distributions according
  to these assumptions. For more information, see ``Online and Linear-Time
  Attention by Enforcing Monotonic Alignments''.
  to these assumptions. For more information, see `Online and Linear-Time
  Attention by Enforcing Monotonic Alignments`.

  Args:
    p_choose_i: Probability of choosing input sequence/memory element i. Should

@@ -14,7 +14,7 @@
# ==============================================================================
"""Module that implements sparsemax and sparsemax loss, see [1].

[1] https://arxiv.org/abs/1602.02068
[1]: https://arxiv.org/abs/1602.02068

## Sparsemax

@@ -31,7 +31,7 @@ def sparsemax(logits, name=None):
  """Computes sparsemax activations [1].

  For each batch `i` and class `j` we have
  sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)
  $$sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)$$

  [1]: https://arxiv.org/abs/1602.02068

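For reference, here is a small NumPy sketch of that formula, with the threshold tau(z) computed as in the cited paper (this is an independent illustration, not the library's kernel):

```python
import numpy as np

def sparsemax(z):
    """Sparsemax of a 1-D logits vector (https://arxiv.org/abs/1602.02068)."""
    z_sorted = np.sort(z)[::-1]
    cssv = np.cumsum(z_sorted)
    k = np.arange(1, len(z) + 1)
    support = 1 + k * z_sorted > cssv     # k(z): largest k with this property
    k_z = k[support][-1]
    tau = (cssv[k_z - 1] - 1.0) / k_z     # the threshold tau(z)
    return np.maximum(z - tau, 0.0)

print(sparsemax(np.array([2.0, 1.0, -1.0])))  # [1. 0. 0.] -- sparse, sums to 1
```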
@@ -405,7 +405,13 @@ tensorflow::Status ConvertGraphDefToTensorRT(
        max_mem_per_engine, static_graph_properties,
        &output_edge_map, precision_mode);
    if (precision_mode == INT8MODE) {
      TF_RETURN_IF_ERROR(GetCalibNode(&p));
      tensorflow::Status status = GetCalibNode(&p);
      if (status != tensorflow::Status::OK()) {
        LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
                     << " due to: \"" << status.ToString()
                     << "\" SKIPPING......( " << subgraph_node_names.size()
                     << " nodes)";
      }
    } else {
      tensorflow::Status status = ConvertSubGraphToTensorRT(&p);
      if (status != tensorflow::Status::OK()) {
@@ -414,8 +420,8 @@ tensorflow::Status ConvertGraphDefToTensorRT(
                     << "\" SKIPPING......( " << subgraph_node_names.size()
                     << " nodes)";
      }
      count++;
    }
    count++;
  }
  graph.ToGraphDef(new_graph_def);
  return tensorflow::Status::OK();

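The control-flow change above follows a common pattern: rather than aborting the whole conversion when one subgraph fails (TF_RETURN_IF_ERROR), log a warning and continue with the remaining subgraphs. A minimal sketch of the idea with hypothetical helper names, in Python for readability:

```python
import logging

def convert_all(subgraphs, convert_one):
    """Try to convert each subgraph; warn about and skip any that fail."""
    converted = []
    for index, subgraph in enumerate(subgraphs):
        try:
            converted.append(convert_one(subgraph))
        except RuntimeError as err:  # stands in for a non-OK tensorflow::Status
            logging.warning("subgraph conversion error for subgraph_index:%d "
                            "due to: %s SKIPPING (%d nodes)",
                            index, err, len(subgraph))
    return converted
```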
@@ -443,7 +443,9 @@ class Converter {
   * 2) Control dependency inputs contain caret at the beginning and we
   *    remove this and annotate the edge as a control dependency.
   ************************************************************************/
  string name = input_name[0] == '^' ? input_name.substr(1) : input_name;
  // skip control nodes
  if (input_name[0] == '^') continue;
  string name = input_name;
  auto first = name.find_first_of(':');
  if (first != string::npos && first + 2 == name.size() &&
      name[first + 1] == '0')
@@ -2262,6 +2264,7 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
  auto ws = new tensorflow::tensorrt::TRTWeightStore();
  TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws));
  Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE);

  std::vector<string> input_names;
  std::vector<tensorflow::DataType> input_dtypes;
  for (const std::pair<int, int>& input : s.input_inds) {
@@ -2270,20 +2273,41 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
    int output_idx = input.second;
    tensorflow::Node* node = s.graph.FindNodeId(node_id);
    auto node_name = node->name();
    input_names.push_back(node_name);  // insert original node name without port
    // TODO(jie): alternative :)
    if (!s.graph_properties.HasOutputProperties(node_name))
    // input_names should use the node name in the graph
    // here it should be the input tensor name -> matching the binding
    // insert original node name without port
    auto tensor_name = node_name;
    if (output_idx != 0) {
      tensor_name = StrCat(tensor_name, ":", output_idx);
    }

    VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name
            << " idx: " << output_idx;

    auto shape_inference_node_name = node_name;
    auto shape_inference_output_idx = output_idx;
    // rewire the shape inference to original node in the graph
    if (s.output_edge_map->count(tensor_name)) {
      shape_inference_node_name = s.output_edge_map->at(tensor_name).second;
      shape_inference_output_idx = s.output_edge_map->at(tensor_name).first;
    }
    if (shape_inference_output_idx < 0) continue;
    VLOG(2) << "shapeinference name: " << shape_inference_node_name
            << " idx: " << shape_inference_output_idx;

    if (!s.graph_properties.HasOutputProperties(shape_inference_node_name))
      return tensorflow::errors::Internal("failed to find input node: " +
                                          node_name);
                                          shape_inference_node_name);

    auto op_info_vec = s.graph_properties.GetOutputProperties(node_name);
    if (static_cast<int>(op_info_vec.size()) < output_idx)
    auto op_info_vec =
        s.graph_properties.GetOutputProperties(shape_inference_node_name);
    if (static_cast<int>(op_info_vec.size()) <= shape_inference_output_idx)
      return tensorflow::errors::Internal(
          "accessing output index of: ", output_idx, ", at node: ", node_name,
          "with output entry from shape_map: ", op_info_vec.size());

    auto op_info = op_info_vec.at(output_idx);
          "accessing output index of: ", shape_inference_output_idx,
          ", at node: ", shape_inference_node_name,
          " with output entry from shape_map: ", op_info_vec.size());

    auto op_info = op_info_vec.at(shape_inference_output_idx);
    tensorflow::DataType tf_dtype = op_info.dtype();
    input_dtypes.push_back(tf_dtype);

@@ -2294,16 +2318,23 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
              << "' failed";
      return type_status;
    }
    TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));

    VLOG(2) << "accessing output index of: " << output_idx
            << ", at node: " << node_name
            << "with output entry from shape_map: " << op_info_vec.size();

    // TODO(ben,jie): update TRT input format/dimension
    nvinfer1::DimsCHW input_dim_psuedo_chw;
    for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;

    // TODO(jie): TRT 3.x only support 4 dimensional input tensor.
    //            update the code once TRT 4.0 comes out.
    if (op_info.shape().dim_size() != 4) {
      string err_str = "Require 4 dimensional input.";
      StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ",
                shape_inference_node_name);
      return tensorflow::errors::Unimplemented(err_str);
    }

    for (int i = 1; i < op_info.shape().dim_size(); i++) {
      VLOG(2) << "dimension: " << i
              << " , size: " << op_info.shape().dim(i).size();
@@ -2312,8 +2343,11 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {

    // TODO(ben,jie): proper way to restore input tensor name?
    auto input_tensor_name = node_name;
    if (output_idx != 0) input_tensor_name = StrCat(node_name, ":", output_idx);
    if (output_idx != 0) {
      input_tensor_name = StrCat(node_name, ":", output_idx);
    }

    input_names.push_back(input_tensor_name);
    nvinfer1::ITensor* input_tensor = converter.network()->addInput(
        input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);

@@ -2377,11 +2411,13 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
      tensor->setType(trt_dtype);
    }

    VLOG(2) << "finished output";
    VLOG(2) << "Finished processing outputs";

    // Build the engine
    op_res->builder_->setMaxBatchSize(s.max_batch_size);
    op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes);
    VLOG(0) << "Max batch size= " << s.max_batch_size
            << " max workspace size= " << s.max_workspace_size_bytes;

    // Build the TRT op
    // TODO(sami,ben,jie): proper naming!
@@ -2475,7 +2511,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
  std::vector<string> input_names;
  std::vector<tensorflow::DataType> input_dtypes;
  for (const std::pair<int, int>& input : s.input_inds) {
    VLOG(2) << "parsing input!!!!!";
    VLOG(2) << "parsing input. Node id= " << input.first;
    int node_id = input.first;
    int output_idx = input.second;
    tensorflow::Node* node = s.graph.FindNodeId(node_id);

tensorflow/core/api_def/base_api/api_def_ClipByValue.pbtxt (new file)
@@ -0,0 +1,36 @@
op {
  graph_op_name: "ClipByValue"
  in_arg {
    name: "t"
    description: <<END
A `Tensor`.
END
  }
  in_arg {
    name: "clip_value_min"
    description: <<END
A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
as `t`. The minimum value to clip by.
END
  }
  in_arg {
    name: "clip_value_max"
    description: <<END
A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
as `t`. The maximum value to clip by.
END
  }
  out_arg {
    name: "output"
    description: <<END
A clipped `Tensor` with the same shape as input 't'.
END
  }
  summary: "Clips tensor values to a specified min and max."
  description: <<END
Given a tensor `t`, this operation returns a tensor of the same type and
shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
Any values less than `clip_value_min` are set to `clip_value_min`. Any values
greater than `clip_value_max` are set to `clip_value_max`.
END
}
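The op itself is registered hidden (see the next api_def), so users reach it through the Python wrapper `tf.clip_by_value`. A short usage sketch in TF 1.x session style:

```python
import tensorflow as tf

t = tf.constant([[-2.0, 0.5], [3.0, 10.0]])
# Bounds may be scalars or tensors with the same shape as `t`.
clipped = tf.clip_by_value(t, clip_value_min=0.0, clip_value_max=1.0)
with tf.Session() as sess:
    print(sess.run(clipped))  # [[0.  0.5] [1.  1. ]]
```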
@@ -0,0 +1,4 @@
op {
  graph_op_name: "ClipByValue"
  visibility: HIDDEN
}
@@ -15,6 +15,9 @@ limitations under the License.

#include "tensorflow/core/common_runtime/process_util.h"

#ifdef INTEL_MKL
#include <omp.h>
#endif
#include <string.h>

#include "tensorflow/core/lib/core/threadpool.h"
@@ -47,10 +50,24 @@ thread::ThreadPool* ComputePool(const SessionOptions& options) {
}

int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
  const int32 t = options.config.inter_op_parallelism_threads();
  if (t != 0) return t;
  const int32 inter_op = options.config.inter_op_parallelism_threads();
  if (inter_op != 0) return inter_op;
#ifdef INTEL_MKL
  // MKL library executes ops in parallel using OMP threads
  // Set inter_op conservatively to avoid thread oversubscription that could
  // lead to severe perf degradations and OMP resource exhaustion
  const int mkl_intra_op = omp_get_max_threads();
  CHECK_GE(mkl_intra_op, 1);
  const int32 mkl_inter_op = std::max(
      (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2);
  VLOG(0) << "Creating new thread pool with default inter op setting: "
          << mkl_inter_op
          << ". Tune using inter_op_parallelism_threads for best performance.";
  return mkl_inter_op;
#else
  // Default to using the number of cores available in the process.
  return port::NumSchedulableCPUs();
#endif
}

thread::ThreadPool* NewThreadPoolFromSessionOptions(

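The MKL default above is a ceiling division of schedulable CPUs by the OMP thread count, floored at 2 so inter-op work is never fully starved. A quick check of the arithmetic:

```python
def mkl_default_inter_op(num_cpus, mkl_intra_op):
    # Mirrors the C++ above: ceil(num_cpus / mkl_intra_op), at least 2.
    return max((num_cpus + mkl_intra_op - 1) // mkl_intra_op, 2)

print(mkl_default_inter_op(56, 14))  # 4
print(mkl_default_inter_op(8, 8))    # 2: the floor keeps inter-op parallelism alive
```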
@@ -11,6 +11,10 @@ load(
    "//tensorflow/core:platform/default/build_config.bzl",
    "tf_protos_grappler",
)
load(
    "//tensorflow/core:platform/default/build_config_root.bzl",
    "if_static",
)

cc_library(
    name = "static_schedule",
@@ -537,11 +541,28 @@ tf_cuda_cc_test(
    ],
)

# This rule is header-only unless the build is static (--config=monolithic). Its
# implementation is included directly in the framework shared object.
cc_library(
    name = "custom_graph_optimizer_registry",
    srcs = ["custom_graph_optimizer_registry.cc"],
    hdrs = ["custom_graph_optimizer_registry.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":custom_graph_optimizer",
        "//tensorflow/core:lib",
    ] + if_static(
        [":custom_graph_optimizer_registry_impl"],
    ),
)

# This rule contains static variables for the optimizer registry. Do not depend
# on it directly; use :custom_graph_optimizer_registry, and link against
# libtensorflow_framework.so for the registry symbols.
cc_library(
    name = "custom_graph_optimizer_registry_impl",
    srcs = ["custom_graph_optimizer_registry.cc"],
    hdrs = ["custom_graph_optimizer_registry.h"],
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":custom_graph_optimizer",
        "//tensorflow/core:lib",

@@ -3549,6 +3549,7 @@ tf_kernel_library(
        "pooling_ops_3d_gpu.cu.cc",
    ],
    deps = [
        ":bounds_check",
        ":conv_2d",
        ":conv_3d",
        ":conv_ops",
@@ -3559,6 +3560,7 @@ tf_kernel_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:nn_ops_op_lib",
        "//tensorflow/core:stream_executor",
        "//third_party/eigen3",
    ],
)

@@ -18,9 +18,7 @@ limitations under the License.
namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Abs", functor::abs, float, Eigen::half, double, int32,
          int64);
#if !defined(IS_MOBILE_PLATFORM)
REGISTER2(UnaryOp, CPU, "ComplexAbs", functor::abs, complex64, complex128);
#endif

#if GOOGLE_CUDA
REGISTER4(UnaryOp, GPU, "Abs", functor::abs, float, Eigen::half, double, int64);

tensorflow/core/kernels/cwise_op_clip.cc (new file)
@@ -0,0 +1,225 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/cwise_op_clip.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Basic coefficient-wise ternary operations.
// This is the case for example of the clip_by_value.
// Device: E.g., CPUDevice, GPUDevice.
// Functor: defined above. E.g., functor::clip.
template <typename Device, typename T>
class ClipOp : public OpKernel {
 public:
  explicit ClipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    const Tensor& in2 = ctx->input(2);

    auto in0_flat = in0.flat<T>();
    auto in1_flat = in1.flat<T>();
    auto in2_flat = in2.flat<T>();
    const Device& d = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out));
    auto out_flat = out->flat<T>();
    if (in1.shape() == in2.shape()) {
      if (in0.shape() == in1.shape()) {
        functor::TernaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
                                            out_flat);
      } else {
        OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in1.shape()),
                    errors::InvalidArgument(
                        "clip_value_min and clip_value_max must be either of "
                        "the same shape as input, or a scalar. ",
                        "input shape: ", in0.shape().DebugString(),
                        "clip_value_min shape: ", in1.shape().DebugString(),
                        "clip_value_max shape: ", in2.shape().DebugString()));
        functor::UnaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
                                          out_flat);
      }
    } else {
      if (in0.shape() == in1.shape()) {
        OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in2.shape()),
                    errors::InvalidArgument(
                        "clip_value_min and clip_value_max must be either of "
                        "the same shape as input, or a scalar. ",
                        "input shape: ", in0.shape().DebugString(),
                        "clip_value_min shape: ", in1.shape().DebugString(),
                        "clip_value_max shape: ", in2.shape().DebugString()));
        functor::BinaryLeftClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
                                               out_flat);
      } else {
        OP_REQUIRES(ctx,
                    (in0.shape() == in2.shape() &&
                     TensorShapeUtils::IsScalar(in1.shape())),
                    errors::InvalidArgument(
                        "clip_value_min and clip_value_max must be either of "
                        "the same shape as input, or a scalar. ",
                        "input shape: ", in0.shape().DebugString(),
                        "clip_value_min shape: ", in1.shape().DebugString(),
                        "clip_value_max shape: ", in2.shape().DebugString()));
        functor::BinaryRightClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
                                                out_flat);
      }
    }
  }
};

namespace functor {
// Unary functor for clip [Tensor, Scalar, Scalar]
template <typename T>
struct UnaryClipFunc {
  UnaryClipFunc(const T& value_min, const T& value_max)
      : value_min(value_min), value_max(value_max) {}
  const T operator()(const T& value) const {
    return std::max(std::min(value, value_max), value_min);
  }
  T value_min;
  T value_max;
};
template <typename T>
struct UnaryClipOp<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
                  typename TTypes<T>::ConstFlat& in1_flat,
                  typename TTypes<T>::ConstFlat& in2_flat,
                  typename TTypes<T>::Flat& out_flat) const {
    out_flat = in0_flat.unaryExpr(UnaryClipFunc<T>(in1_flat(0), in2_flat(0)));
  }
};

// Binary functor for clip [Tensor, Scalar, Tensor]
template <typename T>
struct BinaryRightClipFunc {
  explicit BinaryRightClipFunc(const T& value_min) : value_min(value_min) {}
  const T operator()(const T& value, const T& value_max) const {
    return std::max(std::min(value, value_max), value_min);
  }
  T value_min;
};
template <typename T>
struct BinaryRightClipOp<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
                  typename TTypes<T>::ConstFlat& in1_flat,
                  typename TTypes<T>::ConstFlat& in2_flat,
                  typename TTypes<T>::Flat& out_flat) const {
    out_flat =
        in0_flat.binaryExpr(in2_flat, BinaryRightClipFunc<T>(in1_flat(0)));
  }
};

// Binary functor for clip [Tensor, Tensor, Scalar]
template <typename T>
struct BinaryLeftClipFunc {
  explicit BinaryLeftClipFunc(const T& value_max) : value_max(value_max) {}
  const T operator()(const T& value, const T& value_min) const {
    return std::max(std::min(value, value_max), value_min);
  }
  T value_max;
};
template <typename T>
struct BinaryLeftClipOp<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
                  typename TTypes<T>::ConstFlat& in1_flat,
                  typename TTypes<T>::ConstFlat& in2_flat,
                  typename TTypes<T>::Flat& out_flat) const {
    out_flat =
        in0_flat.binaryExpr(in1_flat, BinaryLeftClipFunc<T>(in2_flat(0)));
  }
};

// Ternary functor for clip [Tensor, Tensor, Tensor]
template <typename T>
struct TernaryClipOp<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
                  typename TTypes<T>::ConstFlat& in1_flat,
                  typename TTypes<T>::ConstFlat& in2_flat,
                  typename TTypes<T>::Flat& out_flat) const {
    out_flat.device(d) = in0_flat.cwiseMin(in2_flat).cwiseMax(in1_flat);
  }
};

#define INSTANTIATE_CPU(T)                         \
  template struct UnaryClipOp<CPUDevice, T>;       \
  template struct BinaryRightClipOp<CPUDevice, T>; \
  template struct BinaryLeftClipOp<CPUDevice, T>;  \
  template struct TernaryClipOp<CPUDevice, T>;
INSTANTIATE_CPU(Eigen::half);
INSTANTIATE_CPU(float);
INSTANTIATE_CPU(double);
INSTANTIATE_CPU(int8);
INSTANTIATE_CPU(int16);
INSTANTIATE_CPU(int32);
INSTANTIATE_CPU(int64);
INSTANTIATE_CPU(uint8);
INSTANTIATE_CPU(uint16);
#undef INSTANTIATE_CPU
}  // namespace functor

#define REGISTER_CPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("ClipByValue").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ClipOp<CPUDevice, type>);

REGISTER_CPU_KERNEL(Eigen::half);
REGISTER_CPU_KERNEL(float);
REGISTER_CPU_KERNEL(double);
REGISTER_CPU_KERNEL(int8);
REGISTER_CPU_KERNEL(int16);
REGISTER_CPU_KERNEL(int32);
REGISTER_CPU_KERNEL(int64);
REGISTER_CPU_KERNEL(uint8);
REGISTER_CPU_KERNEL(uint16);
#undef REGISTER_CPU_KERNEL

#if GOOGLE_CUDA

#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("ClipByValue").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      ClipOp<GPUDevice, type>);
REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(double);
REGISTER_GPU_KERNEL(int8);
REGISTER_GPU_KERNEL(int16);
REGISTER_GPU_KERNEL(int64);
REGISTER_GPU_KERNEL(uint8);
REGISTER_GPU_KERNEL(uint16);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("ClipByValue")
                            .Device(DEVICE_GPU)
                            .HostMemory("t")
                            .HostMemory("clip_value_min")
                            .HostMemory("clip_value_max")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ClipOp<CPUDevice, int32>);

#undef REGISTER_GPU_KERNEL
#endif

}  // namespace tensorflow
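ClipOp's Compute dispatches to one of four functors by comparing the shapes of the three inputs. A NumPy stand-in for that dispatch logic (the actual formula is the same in every branch; what differs is the validation and which specialized kernel runs):

```python
import numpy as np

def clip_dispatch(t, lo, hi):
    """Sketch of ClipOp::Compute's shape dispatch."""
    lo, hi = np.asarray(lo), np.asarray(hi)
    if lo.shape == hi.shape:
        if t.shape == lo.shape:
            branch = "ternary: all three tensors share a shape"
        else:
            assert lo.ndim == 0, "min/max must match input shape or be scalars"
            branch = "unary: scalar min and scalar max"
    elif t.shape == lo.shape:
        assert hi.ndim == 0, "max must be a scalar here"
        branch = "binary-left: tensor min, scalar max"
    else:
        assert t.shape == hi.shape and lo.ndim == 0
        branch = "binary-right: scalar min, tensor max"
    return np.minimum(np.maximum(t, lo), hi), branch

out, branch = clip_dispatch(np.array([-2.0, 0.5, 3.0]), 0.0, 1.0)
print(out, "|", branch)  # [0.  0.5 1. ] | unary: scalar min and scalar max
```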
tensorflow/core/kernels/cwise_op_clip.h (new file)
@ -0,0 +1,61 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_

#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
namespace functor {
// Unary functor for clip [Tensor, Scalar, Scalar]
template <typename Device, typename T>
struct UnaryClipOp {
  void operator()(const Device &d, typename TTypes<T>::ConstFlat &in0_flat,
                  typename TTypes<T>::ConstFlat &in1_flat,
                  typename TTypes<T>::ConstFlat &in2_flat,
                  typename TTypes<T>::Flat &out_flat) const;
};

// Binary functor for clip [Tensor, Scalar, Tensor]
template <typename Device, typename T>
struct BinaryRightClipOp {
  void operator()(const Device &d, typename TTypes<T>::ConstFlat &in0_flat,
                  typename TTypes<T>::ConstFlat &in1_flat,
                  typename TTypes<T>::ConstFlat &in2_flat,
                  typename TTypes<T>::Flat &out_flat) const;
};

// Binary functor for clip [Tensor, Tensor, Scalar]
template <typename Device, typename T>
struct BinaryLeftClipOp {
  void operator()(const Device &d, typename TTypes<T>::ConstFlat &in0_flat,
                  typename TTypes<T>::ConstFlat &in1_flat,
                  typename TTypes<T>::ConstFlat &in2_flat,
                  typename TTypes<T>::Flat &out_flat) const;
};

// Ternary functor for clip [Tensor, Tensor, Tensor]
template <typename Device, typename T>
struct TernaryClipOp {
  void operator()(const Device &d, typename TTypes<T>::ConstFlat &in0_flat,
                  typename TTypes<T>::ConstFlat &in1_flat,
                  typename TTypes<T>::ConstFlat &in2_flat,
                  typename TTypes<T>::Flat &out_flat) const;
};
}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_
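For orientation, the four functor declarations above line up with the operand-shape combinations that `tf.clip_by_value` accepts from Python. A minimal sketch of the four cases (the functor-to-case mapping follows the comments above and is an inference, not something the header spells out):

```python
import tensorflow as tf

x = tf.constant([[1., 2., 3.], [4., 5., 6.]])

# [Tensor, Scalar, Scalar] -- UnaryClipOp
a = tf.clip_by_value(x, 2.0, 4.0)
# [Tensor, Scalar, Tensor] -- BinaryRightClipOp
b = tf.clip_by_value(x, 2.0, tf.fill([2, 3], 4.0))
# [Tensor, Tensor, Scalar] -- BinaryLeftClipOp
c = tf.clip_by_value(x, tf.fill([2, 3], 2.0), 4.0)
# [Tensor, Tensor, Tensor] -- TernaryClipOp
d = tf.clip_by_value(x, tf.fill([2, 3], 2.0), tf.fill([2, 3], 4.0))

with tf.Session() as sess:
    print(sess.run([a, b, c, d]))  # each result is x clamped to [2, 4]
```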
tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc (new file, 134 lines)
@ -0,0 +1,134 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/cwise_op_clip.h"
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

template <typename T>
__global__ void UnaryClipCustomKernel(const int32 size_in, const T *in0,
                                      const T *in1, const T *in2, T *out) {
  CUDA_1D_KERNEL_LOOP(i, size_in) {
    T value = in2[0] < in0[i] ? in2[0] : in0[i];
    out[i] = value < in1[0] ? in1[0] : value;
  }
}

template <typename T>
__global__ void BinaryRightClipCustomKernel(const int32 size_in, const T *in0,
                                            const T *in1, const T *in2,
                                            T *out) {
  CUDA_1D_KERNEL_LOOP(i, size_in) {
    T value = in2[i] < in0[i] ? in2[i] : in0[i];
    out[i] = value < in1[0] ? in1[0] : value;
  }
}

template <typename T>
__global__ void BinaryLeftClipCustomKernel(const int32 size_in, const T *in0,
                                           const T *in1, const T *in2, T *out) {
  CUDA_1D_KERNEL_LOOP(i, size_in) {
    T value = in2[0] < in0[i] ? in2[0] : in0[i];
    out[i] = value < in1[i] ? in1[i] : value;
  }
}

namespace functor {

// Unary functor for clip [Tensor, Scalar, Scalar]
template <typename T>
struct UnaryClipOp<GPUDevice, T> {
  void operator()(const GPUDevice &d, typename TTypes<T>::ConstFlat &in0_flat,
                  typename TTypes<T>::ConstFlat &in1_flat,
                  typename TTypes<T>::ConstFlat &in2_flat,
                  typename TTypes<T>::Flat &out_flat) const {
    CudaLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);

    UnaryClipCustomKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
            out_flat.data());
  }
};

// Binary functor for clip [Tensor, Scalar, Tensor]
template <typename T>
struct BinaryRightClipOp<GPUDevice, T> {
  void operator()(const GPUDevice &d, typename TTypes<T>::ConstFlat &in0_flat,
                  typename TTypes<T>::ConstFlat &in1_flat,
                  typename TTypes<T>::ConstFlat &in2_flat,
                  typename TTypes<T>::Flat &out_flat) const {
    CudaLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);

    BinaryRightClipCustomKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
            out_flat.data());
  }
};

// Binary functor for clip [Tensor, Tensor, Scalar]
template <typename T>
struct BinaryLeftClipOp<GPUDevice, T> {
  void operator()(const GPUDevice &d, typename TTypes<T>::ConstFlat &in0_flat,
                  typename TTypes<T>::ConstFlat &in1_flat,
                  typename TTypes<T>::ConstFlat &in2_flat,
                  typename TTypes<T>::Flat &out_flat) const {
    CudaLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);

    BinaryLeftClipCustomKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
            out_flat.data());
  }
};

// Ternary functor for clip [Tensor, Tensor, Tensor]
template <typename T>
struct TernaryClipOp<GPUDevice, T> {
  void operator()(const GPUDevice &d, typename TTypes<T>::ConstFlat &in0_flat,
                  typename TTypes<T>::ConstFlat &in1_flat,
                  typename TTypes<T>::ConstFlat &in2_flat,
                  typename TTypes<T>::Flat &out_flat) const {
    out_flat.device(d) = in0_flat.cwiseMin(in2_flat).cwiseMax(in1_flat);
  }
};

#define INSTANTIATE_GPU(T)                         \
  template struct UnaryClipOp<GPUDevice, T>;       \
  template struct BinaryRightClipOp<GPUDevice, T>; \
  template struct BinaryLeftClipOp<GPUDevice, T>;  \
  template struct TernaryClipOp<GPUDevice, T>;
INSTANTIATE_GPU(Eigen::half);
INSTANTIATE_GPU(float);
INSTANTIATE_GPU(double);
INSTANTIATE_GPU(int8);
INSTANTIATE_GPU(int16);
INSTANTIATE_GPU(int32);
INSTANTIATE_GPU(int64);
INSTANTIATE_GPU(uint8);
INSTANTIATE_GPU(uint16);
#undef INSTANTIATE_GPU

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA
@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
@ -56,7 +57,7 @@ template <typename Device, typename T>
static void SpatialMaxPoolWithArgMaxHelper(
    OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
    Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
    const PoolParameters& params, const Padding& padding) {
    const PoolParameters& params) {
  typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
      ConstEigenMatrixMap;
  typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
@ -151,7 +152,7 @@ static void SpatialMaxPoolWithArgMaxHelper(
      }
    }

    {
    if (input_backprop != nullptr) {
      auto input_backprop_flat = input_backprop->flat<T>();
      auto out_arg_max_flat = output_arg_max->flat<int64>();
      auto out_backprop_flat = out_backprop.flat<T>();
@ -173,9 +174,9 @@ static void SpatialMaxPoolWithArgMaxHelper(
          // Although this check is in the inner loop, it is worth its value
          // so we don't end up with memory corruptions. Our benchmark shows that
          // the performance impact is quite small
          CHECK(input_backprop_index >= in_start && input_backprop_index < in_end)
              << "Invalid input backprop index: " << input_backprop_index << ", "
              << in_start << ", " << in_end;
          // CHECK(input_backprop_index >= in_start && input_backprop_index <
          // in_end)
          FastBoundsCheck(input_backprop_index - in_start, in_end - in_start);
          input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
        }
      }
@ -293,7 +294,7 @@ class MaxPoolingGradOp : public OpKernel {

    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
        context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
        out_backprop, params, padding_);
        out_backprop, params);
  }

 private:
@ -869,6 +870,17 @@ class MaxPoolingNoMaskV2Op : public OpKernel {
template <typename Device, typename T>
struct LaunchMaxPoolingWithArgmax;

template <typename T>
struct LaunchMaxPoolingWithArgmax<CPUDevice, T> {
  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& input, Tensor* output, Tensor* argmax,
                     bool propagate_nans) {
    Tensor unused;
    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
        context, output, argmax, nullptr, input, unused, params);
  }
};

template <typename Device, typename T>
class MaxPoolingWithArgmaxOp : public OpKernel {
 public:
@ -921,6 +933,53 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
template <typename Device, typename T>
struct LaunchMaxPoolingGradWithArgmax;

template <typename T>
struct LaunchMaxPoolingGradWithArgmax<CPUDevice, T> {
  typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
      EigenMatrixMap;

  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& grad_in, const Tensor& argmax,
                     Tensor* grad_out) {
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());

    auto shard = [&grad_in, &argmax, &grad_out](int64 start, int64 limit) {
      const int64 batch_size =
          GetTensorDim(grad_out->shape(), FORMAT_NHWC, 'N');
      const int64 output_size_per_batch = grad_out->NumElements() / batch_size;
      const int64 input_size_per_batch = grad_in.NumElements() / batch_size;

      {
        auto grad_out_flat = grad_out->flat<T>();
        auto argmax_flat = argmax.flat<int64>();
        auto grad_in_flat = grad_in.flat<T>();

        const int64 output_start = start * output_size_per_batch;
        const int64 output_end = limit * output_size_per_batch;
        EigenMatrixMap inputShard(grad_out_flat.data() + output_start, 1,
                                  output_end - output_start);
        inputShard.setConstant(T(0));

        const int input_start = start * input_size_per_batch;
        const int input_end = limit * input_size_per_batch;
        for (int64 index = input_start; index < input_end; index++) {
          const int64 grad_out_index = argmax_flat(index);
          CHECK(grad_out_index >= output_start && grad_out_index < output_end)
              << "Invalid output gradient index: " << grad_out_index << ", "
              << output_start << ", " << output_end;
          grad_out_flat(grad_out_index) += grad_in_flat(index);
        }
      }
    };

    const int64 batch_size = GetTensorDim(grad_out->shape(), FORMAT_NHWC, 'N');
    const int64 shard_cost = grad_out->NumElements() / batch_size;
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          shard_cost, shard);
  }
};

template <typename Device, typename T>
class MaxPoolingGradWithArgmaxOp : public OpKernel {
 public:
@ -1309,7 +1368,17 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
                              .HostMemory("ksize")                   \
                              .HostMemory("strides")                 \
                              .TypeConstraint<T>("T"),               \
                          MaxPoolingGradGradOp<D##Device, T>);
                          MaxPoolingGradGradOp<D##Device, T>);       \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")                  \
                              .Device(DEVICE_##D)                    \
                              .TypeConstraint<int64>("Targmax")      \
                              .TypeConstraint<T>("T"),               \
                          MaxPoolingWithArgmaxOp<D##Device, T>);     \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")              \
                              .Device(DEVICE_##D)                    \
                              .TypeConstraint<T>("T")                \
                              .TypeConstraint<int64>("Targmax"),     \
                          MaxPoolingGradWithArgmaxOp<D##Device, T>);

// Below kernels implemented only for CPU device.
#define REGISTER_CPU_ONLY_POOL_KERNELS(T) \
@ -1374,16 +1443,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
                              .HostMemory("strides")                 \
                              .TypeConstraint<T>("T"),               \
                          MaxPoolingNoMaskV2Op<GPUDevice, T>);       \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")                  \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<int64>("Targmax")      \
                              .TypeConstraint<T>("T"),               \
                          MaxPoolingWithArgmaxOp<GPUDevice, T>);     \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")              \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<T>("T")                \
                              .TypeConstraint<int64>("Targmax"),     \
                          MaxPoolingGradWithArgmaxOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax")          \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<T>("T")                \
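With `MaxPoolWithArgmax` and `MaxPoolGradWithArgmax` now registered through the device-generic macro, the op is no longer CUDA-only. A minimal user-level sketch (TF 1.x API, assuming a build that includes this change):

```python
import tensorflow as tf

x = tf.reshape(
    tf.constant([1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]), [1, 3, 3, 1])
pooled, argmax = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding='VALID')

# Runs on CPU now that the kernel above is registered for DEVICE_CPU.
with tf.Session() as sess:
    out, idx = sess.run([pooled, argmax])
    print(out.ravel())  # max of each 2x2 window
    print(idx.ravel())  # flattened index of each max, e.g. [0 1 3 5]
```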
@ -16,6 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_

// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"

// Unfortunately we can't add the #include, since it breaks compilation for
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.

// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
@ -117,7 +117,11 @@ REGISTER_OP("TakeDataset")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle count_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("SkipDataset")
    .Input("input_dataset: variant")
@ -125,7 +129,11 @@ REGISTER_OP("SkipDataset")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle count_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("BytesProducedStatsDataset")
    .Input("input_dataset: variant")
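The new shape functions simply assert that the `count` input (input 1) is a scalar before returning the usual scalar variant handle. From Python this is the familiar `take`/`skip` path (a minimal sketch, TF 1.x `tf.data` API):

```python
import tensorflow as tf

# `count` arguments must be scalars; the shape function above now enforces
# this at graph-construction time.
dataset = tf.data.Dataset.range(10).take(5).skip(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    print([sess.run(next_element) for _ in range(3)])  # [2, 3, 4]
```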
@ -1558,6 +1558,14 @@ REGISTER_OP("Bucketize")
    .Attr("boundaries: list(float)")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("ClipByValue")
    .Input("t: T")
    .Input("clip_value_min: T")
    .Input("clip_value_max: T")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnchangedShape);

#ifdef INTEL_MKL
REGISTER_OP("_MklAddN")
    .Input("inputs: N * T")
@ -31,13 +31,14 @@ limitations under the License.
  __attribute__((__format__(__printf__, string_index, first_to_check)))
#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) \
  __attribute__((__format__(__scanf__, string_index, first_to_check)))
#elif defined(COMPILER_MSVC)
#elif defined(_MSC_VER)
// Non-GCC equivalents
#define TF_ATTRIBUTE_NORETURN __declspec(noreturn)
#define TF_ATTRIBUTE_ALWAYS_INLINE
#define TF_ATTRIBUTE_ALWAYS_INLINE __forceinline
#define TF_ATTRIBUTE_NOINLINE
#define TF_ATTRIBUTE_UNUSED
#define TF_ATTRIBUTE_COLD
#define TF_ATTRIBUTE_WEAK
#define TF_MUST_USE_RESULT
#define TF_PACKED
#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check)
@ -57,7 +58,7 @@ limitations under the License.
#endif

// Control visibility outside .so
#if defined(COMPILER_MSVC)
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_EXPORT __declspec(dllexport)
#else
@ -65,7 +66,7 @@ limitations under the License.
#endif  // TF_COMPILE_LIBRARY
#else
#define TF_EXPORT __attribute__((visibility("default")))
#endif  // COMPILER_MSVC
#endif  // _WIN32

#ifdef __has_builtin
#define TF_HAS_BUILTIN(x) __has_builtin(x)
@ -148,19 +148,7 @@ viewing. Do not include url parameters in the source code URL.

Before building the documentation, you must first set up your environment by
doing the following:

1.  If pip isn't installed on your machine, install it now by issuing the
    following command:

        $ sudo easy_install pip

2.  Use pip to install codegen, mock, and pandas by issuing the following
    command (Note: If you are using
    a [virtualenv](https://virtualenv.pypa.io/en/stable/) to manage your
    dependencies, you may not want to use sudo for these installations):

        $ sudo pip install codegen mock pandas

3.  If bazel is not installed on your machine, install it now. If you are on
1.  If bazel is not installed on your machine, install it now. If you are on
    Linux, install bazel by issuing the following command:

        $ sudo apt-get install bazel  # Linux
@ -168,10 +156,10 @@ following command:

    If you are on Mac OS, find bazel installation instructions on
    [this page](https://bazel.build/versions/master/docs/install.html#mac-os-x).

4.  Change directory to the top-level `tensorflow` directory of the TensorFlow
2.  Change directory to the top-level `tensorflow` directory of the TensorFlow
    source code.

5.  Run the `configure` script and answer its prompts appropriately for your
3.  Run the `configure` script and answer its prompts appropriately for your
    system.

        $ ./configure
@ -530,56 +530,58 @@ form [described below](#attr_types).

For example, if you'd like the `ZeroOut` op to preserve a user-specified index,
instead of only the 0th element, you can register the op like so:
<pre class="prettyprint"><code class="lang-cpp">
REGISTER\_OP("ZeroOut")
    <b>.Attr("preserve\_index: int")</b>
    .Input("to\_zero: int32")
```c++
REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");
</code></pre>
```

(Note that the set of [attribute types](#attr_types) is different from the
@{tf.DType$tensor types} used for inputs and outputs.)

Your kernel can then access this attr in its constructor via the `context`
parameter:
<pre class="prettyprint"><code class="lang-cpp">
```c++
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction\* context) : OpKernel(context) {<b>
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the index of the value to preserve
    OP\_REQUIRES\_OK(context,
                   context->GetAttr("preserve\_index", &preserve\_index\_));
    // Check that preserve\_index is positive
    OP\_REQUIRES(context, preserve\_index_ >= 0,
                errors::InvalidArgument("Need preserve\_index >= 0, got ",
                                        preserve\_index_));
  </b>}
  void Compute(OpKernelContext\* context) override {
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // Check that preserve_index is positive
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
  }
  <b>private:
  int preserve\_index\_;</b>
 private:
  int preserve_index_;
};
</code></pre>
```

which can then be used in the `Compute` method:
<pre class="prettyprint"><code class="lang-cpp">
  void Compute(OpKernelContext\* context) override {
```c++
  void Compute(OpKernelContext* context) override {
    // ...
<br/>
    <b>// We're using saved attr to validate potentially dynamic input
    // So we check that preserve\_index is in range
    OP\_REQUIRES(context, preserve\_index_ < input.dimension(0),
                errors::InvalidArgument("preserve\_index out of range"));<br/>
    </b>// Set all the elements of the output tensor to 0

    // We're using saved attr to validate potentially dynamic input
    // So we check that preserve_index is in range
    OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output\_flat(i) = 0;
    }<br/>
    <b>// Preserve the requested input value
    output\_flat(preserve\_index\_) = input(preserve\_index\_);</b>
  }

    // Preserve the requested input value
    output_flat(preserve_index_) = input(preserve_index_);
  }
</code></pre>
```

#### Attr types

@ -725,12 +727,12 @@ you would then register an `OpKernel` for each supported type.

For instance, if you'd like the `ZeroOut` op to work on `float`s
in addition to `int32`s, your op registration might look like:
<pre class="prettyprint"><code class="lang-cpp">
REGISTER\_OP("ZeroOut")
    <b>.Attr("T: {float, int32}")</b>
    .Input("to\_zero: <b>T</b>")
    .Output("zeroed: <b>T</b>");
</code></pre>
```c++
REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");
```

Your op registration now specifies that the input's type must be `float`, or
`int32`, and that its output will be the same type, since both have type `T`.
@ -790,66 +792,73 @@ Your op registration now specifies that the input's type must be `float`, or
> """
> ```

<pre class="prettyprint"><code class="lang-cpp">
\#include "tensorflow/core/framework/op_kernel.h"<br/>
class ZeroOut<b>Int32</b>Op : public OpKernel {
```c++
#include "tensorflow/core/framework/op_kernel.h"

class ZeroOutInt32Op : public OpKernel {
  // as before
};<br/>
class ZeroOut<b>Float</b>Op : public OpKernel {
};

class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOut<b>Float</b>Op(OpKernelConstruction\* context)
      : OpKernel(context) {}<br/>
  void Compute(OpKernelContext\* context) override {
  explicit ZeroOutFloatOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input\_tensor = context->input(0);
    auto input = input\_tensor.flat<<b>float</b>>();<br/>
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();

    // Create an output tensor
    Tensor* output = NULL;
    OP\_REQUIRES\_OK(context,
                   context->allocate\_output(0, input_tensor.shape(), &output));
    auto output\_flat = output->template flat<<b>float</b>>();<br/>
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output\_flat(i) = 0;
    }<br/>
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output\_flat(0) = input(0);
    if (N > 0) output_flat(0) = input(0);
  }
};<br/><b>
// Note that TypeConstraint<int32>("T") means that attr "T" (defined
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.</b>
REGISTER\_KERNEL\_BUILDER(
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
        .Device(DEVICE\_CPU)
        <b>.TypeConstraint<int32>("T"),</b>
    ZeroOutOp<b>Int32</b>);
<b>REGISTER\_KERNEL\_BUILDER(
        .Device(DEVICE_CPU)
        .TypeConstraint<int32>("T"),
    ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
        .Device(DEVICE\_CPU)
        .TypeConstraint<float>("T"),
        .Device(DEVICE_CPU)
        .TypeConstraint<float>("T"),
    ZeroOutFloatOp);
</b></code></pre>
```

> To preserve [backwards compatibility](#backwards-compatibility), you should
> specify a [default value](#default-values-constraints) when adding an attr to
> an existing op:
>
> <pre class="prettyprint"><code class="lang-cpp">
> REGISTER\_OP("ZeroOut")
>     <b>.Attr("T: {float, int32} = DT_INT32")</b>
>     .Input("to\_zero: T")
> ```c++
> REGISTER_OP("ZeroOut")
>     .Attr("T: {float, int32} = DT_INT32")
>     .Input("to_zero: T")
>     .Output("zeroed: T")
> </code></pre>
> ```

Let's say you wanted to add more types, say `double`:
<pre class="prettyprint"><code class="lang-cpp">
REGISTER\_OP("ZeroOut")
    <b>.Attr("T: {float, <b>double,</b> int32}")</b>
    .Input("to\_zero: <b>T</b>")
    .Output("zeroed: <b>T</b>");
</code></pre>
```c++
REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");
```

Instead of writing another `OpKernel` with redundant code as above, often you
will be able to use a C++ template instead. You will still have one kernel
@ -546,7 +546,7 @@ In brief, here's what the three graphs tell you:

*   accuracy: The accuracy is recorded by the following two lines:

    *   `eval_metric_ops={'my_accuracy': accuracy})`, during evaluation.
    *   `eval_metric_ops={'my_accuracy': accuracy}`, during evaluation.
    *   `tf.summary.scalar('accuracy', accuracy[1])`, during training.

These TensorBoard graphs are one of the main reasons it's important to pass a
@ -113,6 +113,6 @@ If executing `a.out` fails, ask yourself the following questions:
*   Did you export those environment variables?

If you are still seeing build or execution error messages, search (or post to)
[StackOverflow](www.stackoverflow.com/questions/tagged/tensorflow) for
[StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow) for
possible solutions.
@ -475,7 +475,7 @@ optimizations.
### TensorFlow with Intel® MKL DNN

Intel® has added optimizations to TensorFlow for Intel® Xeon® and Intel® Xeon
Phi™ though the use of Intel® Math Kernel Library for Deep Neural Networks
Phi™ through the use of the Intel® Math Kernel Library for Deep Neural Networks
(Intel® MKL-DNN) optimized primitives. The optimizations also provide speedups
for the consumer line of processors, e.g. i5 and i7 Intel processors. The Intel
published paper
@ -581,9 +581,9 @@ Each variable that impacts performance is discussed below.
    for optimal settings.

*   **intra_op_parallelism_threads**: Setting this equal to the number of
    physical cores is recommended. Setting the value to 0, which is the default
    and will result in the value being set to the number of logical cores, is an
    option to try for some architectures. This value and `OMP_NUM_THREADS`
    physical cores is recommended. Setting the value to 0, which is the default,
    results in the value being set to the number of logical cores - this is an
    alternate option to try for some architectures. This value and `OMP_NUM_THREADS`
    should be equal.

*   **inter_op_parallelism_threads**: Setting this equal to the number of
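To make the two thread-pool settings above concrete, here is a minimal sketch; the core counts are placeholders to tune per machine, not recommendations from this guide:

```python
import os
import tensorflow as tf

NUM_PHYSICAL_CORES = 8  # placeholder: set to your machine's physical core count

# Keep OMP_NUM_THREADS equal to intra_op_parallelism_threads, per the note above.
os.environ['OMP_NUM_THREADS'] = str(NUM_PHYSICAL_CORES)

config = tf.ConfigProto(
    intra_op_parallelism_threads=NUM_PHYSICAL_CORES,
    inter_op_parallelism_threads=2)  # placeholder: e.g. one per socket

with tf.Session(config=config) as sess:
    pass  # build and run the graph here
```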
@ -4,29 +4,28 @@

[TOC]

TensorFlow debugger (**tfdbg**) is a specialized debugger for TensorFlow. It
lets you view the internal structure and states of running TensorFlow graphs
during training and inference, which is difficult to debug with general-purpose
debuggers such as Python's `pdb` due to TensorFlow's computation-graph paradigm.
`tfdbg` is a specialized debugger for TensorFlow. It lets you view the internal
structure and states of running TensorFlow graphs during training and inference,
which is difficult to debug with general-purpose debuggers such as Python's `pdb`
due to TensorFlow's computation-graph paradigm.

> NOTE: TensorFlow debugger uses a
> [curses](https://en.wikipedia.org/wiki/Curses_\(programming_library\))-based
> text user interface. On Mac OS X, the `ncurses` library is required and can
> be installed with `brew install homebrew/dupes/ncurses`. On Windows, curses
> isn't as well supported, so a
> [readline](https://en.wikipedia.org/wiki/GNU_Readline)-based interface can
> be used with tfdbg by installing `pyreadline` with pip.
> If you use Anaconda3, you can install it with a command
> such as `"C:\Program Files\Anaconda3\Scripts\pip.exe" install pyreadline`.
> Unofficial Windows curses packages can be downloaded
> [here](https://www.lfd.uci.edu/~gohlke/pythonlibs/#curses), then subsequently
> installed using `pip install <your_version>.whl`, however curses on Windows
> may not work as reliably as curses on Linux or Mac.
This guide focuses on the command-line interface (CLI) of `tfdbg`. For a guide on
how to use the graphical user interface (GUI) of tfdbg, i.e., the
**TensorBoard Debugger Plugin**, please visit
[its README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md).

> NOTE: This guide focuses on the command-line interface (CLI) of tfdbg. For
> guide on how to use the graphical user interface (GUI) of tfdbg, i.e., the
> **TensorBoard Debugger Plugin**, please visit
> [its README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md).
Note: The TensorFlow debugger uses a
[curses](https://en.wikipedia.org/wiki/Curses_\(programming_library\))-based text
user interface. On Mac OS X, the `ncurses` library is required and can be
installed with `brew install homebrew/dupes/ncurses`. On Windows, curses isn't as
well supported, so a [readline](https://en.wikipedia.org/wiki/GNU_Readline)-based
interface can be used with tfdbg by installing `pyreadline` with `pip`. If you
use Anaconda3, you can install it with a command such as
`"C:\Program Files\Anaconda3\Scripts\pip.exe" install pyreadline`. Unofficial
Windows curses packages can be downloaded
[here](https://www.lfd.uci.edu/~gohlke/pythonlibs/#curses), then subsequently
installed using `pip install <your_version>.whl`, however curses on Windows may
not work as reliably as curses on Linux or Mac.

This tutorial demonstrates how to use the **tfdbg** CLI to debug the appearance
of [`nan`s](https://en.wikipedia.org/wiki/NaN)
@ -748,16 +747,16 @@ There are three possible workarounds or solutions:
    to which tfdbg dumps the debug data. You can use it to let tfdbg dump the
    debug data on a disk with larger free space. For example:

    ``` python
    # For LocalCLIDebugWrapperSession
    sess = tf_debug.LocalCLIDebugWrapperSession(dump_root="/with/lots/of/space")

    # For LocalCLIDebugHook
    hooks = [tf_debug.LocalCLIDebugHook(dump_root="/with/lots/of/space")]
    ```
    ```python
    # For LocalCLIDebugWrapperSession
    sess = tf_debug.LocalCLIDebugWrapperSession(dump_root="/with/lots/of/space")

    # For LocalCLIDebugHook
    hooks = [tf_debug.LocalCLIDebugHook(dump_root="/with/lots/of/space")]
    ```
    Make sure that the directory pointed to by dump_root is empty or nonexistent.
    tfdbg cleans up the dump directories before exiting.
    `tfdbg` cleans up the dump directories before exiting.

*   Reduce the batch size used during the runs.
*   Use the filtering options of tfdbg's `run` command to watch only specific
    nodes in the graph. For example:
@ -835,6 +835,7 @@ py_library(
    srcs = ["framework/tensor_shape.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":dtypes",
        ":util",
        "//tensorflow/core:protos_all_py",
    ],
@ -651,6 +651,11 @@ QUANTIZED_DTYPES = frozenset([
])
tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")

_PYTHON_TO_TF = {
    float: float32,
    bool: bool,
}


@tf_export("as_dtype")
def as_dtype(type_value):
@ -682,6 +687,11 @@ def as_dtype(type_value):
  except KeyError:
    pass

  try:
    return _PYTHON_TO_TF[type_value]
  except KeyError:
    pass

  if isinstance(type_value, np.dtype):
    # The numpy dtype for strings is variable length. We can not compare
    # dtype with a single constant (np.string does not exist) to decide
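The practical effect of the `_PYTHON_TO_TF` fallback is that plain Python types now resolve directly through `tf.as_dtype` (a minimal sketch, mirroring the test added below):

```python
import tensorflow as tf

assert tf.as_dtype(float) == tf.float32
assert tf.as_dtype(bool) == tf.bool
```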
@ -295,6 +295,10 @@ class TypesTest(test_util.TensorFlowTestCase):
    self.assertNotEqual(dtypes.int32, int)
    self.assertNotEqual(dtypes.float64, 2.1)

  def testPythonTypesConversion(self):
    self.assertIs(dtypes.float32, dtypes.as_dtype(float))
    self.assertIs(dtypes.bool, dtypes.as_dtype(bool))

  def testReduce(self):
    for enum in dtypes._TYPE_TO_STRING:
      dtype = dtypes.DType(enum)
@ -307,3 +311,4 @@ class TypesTest(test_util.TensorFlowTestCase):

if __name__ == "__main__":
  googletest.main()
@ -37,7 +37,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_logging_ops
@ -1362,7 +1361,7 @@ class UnrollLSTMTest(test.TestCase):
          value=math_ops.matmul(xm, weights), num_or_size_splits=4, axis=1)
      new_c = math_ops.sigmoid(f_g) * cprev + math_ops.sigmoid(
          i_g) * math_ops.tanh(i_i)
      new_c = clip_ops.clip_by_value(new_c, -50.0, 50.0)
      new_c = math_ops.maximum(math_ops.minimum(new_c, 50.0), -50.0)
      new_m = math_ops.sigmoid(o_g) * math_ops.tanh(new_c)
      return new_m, new_c
@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

@ -30,6 +31,8 @@ class Dimension(object):
    """Creates a new Dimension with the given value."""
    if value is None:
      self._value = None
    elif isinstance(value, dtypes.DType):
      raise TypeError("Cannot convert %s to Dimension" % value)
    else:
      self._value = int(value)
      if (not isinstance(value, compat.bytes_or_text_types) and
@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
@ -184,6 +185,10 @@ class DimensionTest(test_util.TensorFlowTestCase):
    self.assertEqual(str(tensor_shape.Dimension(7)), "7")
    self.assertEqual(str(tensor_shape.Dimension(None)), "?")

  def testUnsupportedType(self):
    with self.assertRaises(TypeError):
      tensor_shape.Dimension(dtypes.string)

  def testMod(self):
    four = tensor_shape.Dimension(4)
    nine = tensor_shape.Dimension(9)
@ -19,9 +19,9 @@ from __future__ import division
from __future__ import print_function

from collections import defaultdict
import sys

import numpy as np
import six
from tensorflow.python.util.tf_export import tf_export

@ -160,13 +160,11 @@ def ask_to_proceed_with_overwrite(filepath):
  Returns:
      True if we can proceed with overwrite, False otherwise.
  """
  get_input = input
  if sys.version_info[:2] <= (2, 7):
    get_input = raw_input
  overwrite = get_input('[WARNING] %s already exists - overwrite? '
                        '[y/n]' % (filepath))
  while overwrite not in ['y', 'n']:
    overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
  overwrite = six.moves.input('[WARNING] %s already exists - overwrite? '
                              '[y/n]' % (filepath)).strip().lower()
  while overwrite not in ('y', 'n'):
    overwrite = six.moves.input('Enter "y" (overwrite) or "n" '
                                '(cancel).').strip().lower()
  if overwrite == 'n':
    return False
  print('[TIP] Next time specify overwrite=True!')
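For context, `six.moves.input` resolves to `raw_input` on Python 2 and `input` on Python 3, which is exactly what the removed `sys.version_info` check emulated by hand. A minimal sketch:

```python
import six

# Same behavior on Python 2 and Python 3; no manual version check needed.
answer = six.moves.input('overwrite? [y/n] ').strip().lower()
print('proceeding' if answer == 'y' else 'cancelled')
```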
@ -19,16 +19,33 @@ from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.platform import test


class ClipTest(test.TestCase):

  def DISABLED_testClipByValueGradient(self):
    inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
    outputs_1 = clip_ops.clip_by_value(inputs, 0.5, 3.5)
    min_val = constant_op.constant([0.5, 0.5, 0.5, 0.5], dtype=dtypes.float32)
    max_val = constant_op.constant([3.5, 3.5, 3.5, 3.5], dtype=dtypes.float32)
    outputs_2 = clip_ops.clip_by_value(inputs, min_val, max_val)
    with self.test_session():
      error_1 = gradient_checker.compute_gradient_error(inputs, [4], outputs_1,
                                                        [4])
      self.assertLess(error_1, 1e-4)

      error_2 = gradient_checker.compute_gradient_error(inputs, [4], outputs_2,
                                                        [4])
      self.assertLess(error_2, 1e-4)

  # ClipByValue test
  def testClipByValue(self):
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-5.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
      np_ans = [[-4.4, 2.0, 3.0], [4.0, 4.4, 4.4]]
      clip_value = 4.4
@ -37,8 +54,76 @@ class ClipTest(test.TestCase):

      self.assertAllClose(np_ans, tf_ans)

  # [Tensor, Scalar, Scalar]
  def DISABLED_testClipByValue0Type(self):
    for dtype in [
        dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8,
        dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16
    ]:
      with self.test_session(use_gpu=True):
        x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
        np_ans = [[2, 2, 3], [4, 4, 4]]
        clip_value_min = 2
        clip_value_max = 4
        ans = clip_ops.clip_by_value(x, clip_value_min, clip_value_max)
        tf_ans = ans.eval()

        self.assertAllClose(np_ans, tf_ans)

  # [Tensor, Tensor, Scalar]
  def DISABLED_testClipByValue1Type(self):
    for dtype in [
        dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8,
        dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16
    ]:
      with self.test_session(use_gpu=True):
        x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
        np_ans = [[2, 2, 3], [4, 4, 4]]
        clip_value_min = constant_op.constant(
            [2, 2, 2, 3, 3, 3], shape=[2, 3], dtype=dtype)
        clip_value_max = 4
        ans = clip_ops.clip_by_value(x, clip_value_min, clip_value_max)
        tf_ans = ans.eval()

        self.assertAllClose(np_ans, tf_ans)

  # [Tensor, Scalar, Tensor]
  def DISABLED_testClipByValue2Type(self):
    for dtype in [
        dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8,
        dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16
    ]:
      with self.test_session(use_gpu=True):
        x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
        np_ans = [[4, 4, 4], [4, 5, 6]]
        clip_value_min = 4
        clip_value_max = constant_op.constant(
            [6, 6, 6, 6, 6, 6], shape=[2, 3], dtype=dtype)
        ans = clip_ops.clip_by_value(x, clip_value_min, clip_value_max)
        tf_ans = ans.eval()

        self.assertAllClose(np_ans, tf_ans)

  # [Tensor, Tensor, Tensor]
  def DISABLED_testClipByValue3Type(self):
    for dtype in [
        dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8,
        dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16
    ]:
      with self.test_session(use_gpu=True):
        x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
        np_ans = [[2, 2, 3], [5, 5, 6]]
        clip_value_min = constant_op.constant(
            [2, 2, 2, 5, 5, 5], shape=[2, 3], dtype=dtype)
        clip_value_max = constant_op.constant(
            [5, 5, 5, 7, 7, 7], shape=[2, 3], dtype=dtype)
        ans = clip_ops.clip_by_value(x, clip_value_min, clip_value_max)
        tf_ans = ans.eval()

        self.assertAllClose(np_ans, tf_ans)

  def testClipByValueBadShape(self):
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-5.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3, 1])
      # Use a nonsensical shape.
      clip = constant_op.constant([1.0, 2.0])
@ -48,6 +133,7 @@ class ClipTest(test.TestCase):
      _ = clip_ops.clip_by_value(x, 1.0, clip)

  def testClipByValueNonFinite(self):
    # TODO(b/78016351): Enable test on GPU once the bug is fixed.
    with self.test_session():
      x = constant_op.constant([float('NaN'), float('Inf'), -float('Inf')])
      np_ans = [float('NaN'), 4.0, -4.0]
@ -60,7 +146,7 @@ class ClipTest(test.TestCase):
  # ClipByNorm tests
  def testClipByNormClipped(self):
    # Norm clipping when clip_norm < 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      # Norm of x = sqrt(3^2 + 4^2) = 5
      np_ans = [[-2.4, 0.0, 0.0], [3.2, 0.0, 0.0]]
@ -76,7 +162,7 @@ class ClipTest(test.TestCase):
      self.assertAllClose(np_ans, tf_ans_tensor)

  def testClipByNormBadShape(self):
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3, 1])
      # Use a nonsensical shape.
      clip = constant_op.constant([1.0, 2.0])
@ -85,7 +171,7 @@ class ClipTest(test.TestCase):

  def testClipByNormNotClipped(self):
    # No norm clipping when clip_norm >= 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      # Norm of x = sqrt(3^2 + 4^2) = 5
      np_ans = [[-3.0, 0.0, 0.0], [4.0, 0.0, 0.0]]
@ -97,7 +183,7 @@ class ClipTest(test.TestCase):

  def testClipByNormZero(self):
    # No norm clipping when norm = 0
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
      # Norm = 0, no changes
      np_ans = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
@ -109,7 +195,7 @@ class ClipTest(test.TestCase):

  def testClipByNormClippedWithDim0(self):
    # Norm clipping when clip_norm < 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 3.0], shape=[2, 3])
      # Norm of x[:, 0] = sqrt(3^2 + 4^2) = 5, x[:, 2] = 3
      np_ans = [[-2.4, 0.0, 0.0], [3.2, 0.0, 3.0]]
@ -121,7 +207,7 @@ class ClipTest(test.TestCase):

  def testClipByNormClippedWithDim1(self):
    # Norm clipping when clip_norm < 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 3.0], shape=[2, 3])
      # Norm of x[0, :] = 3, x[1, :] = sqrt(3^2 + 4^2) = 5
      np_ans = [[-3.0, 0.0, 0.0], [3.2, 0.0, 2.4]]
@ -133,7 +219,7 @@ class ClipTest(test.TestCase):

  def testClipByNormNotClippedWithAxes(self):
    # No norm clipping when clip_norm >= 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 3.0], shape=[2, 3])
      # Norm of x[0, :] = 3, x[1, :] = sqrt(3^2 + 4^2) = 5
      np_ans = [[-3.0, 0.0, 0.0], [4.0, 0.0, 3.0]]
@ -146,7 +232,7 @@ class ClipTest(test.TestCase):
  # ClipByGlobalNorm tests
  def testClipByGlobalNormClipped(self):
    # Norm clipping when clip_norm < 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      x1 = constant_op.constant([1.0, -2.0])
      # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
@ -167,7 +253,7 @@ class ClipTest(test.TestCase):

  def testClipByGlobalNormClippedTensor(self):
    # Norm clipping when clip_norm < 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      x1 = constant_op.constant([1.0, -2.0])
      # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
@ -188,7 +274,7 @@ class ClipTest(test.TestCase):

  def testClipByGlobalNormSupportsNone(self):
    # Norm clipping when clip_norm < 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      x1 = constant_op.constant([1.0, -2.0])
      # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
@ -211,7 +297,7 @@ class ClipTest(test.TestCase):

  def testClipByGlobalNormWithIndexedSlicesClipped(self):
    # Norm clipping when clip_norm < 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      x1 = ops.IndexedSlices(
          constant_op.constant([1.0, -2.0]), constant_op.constant([3, 4]))
@ -244,7 +330,7 @@ class ClipTest(test.TestCase):

  def testClipByGlobalNormNotClipped(self):
    # No norm clipping when clip_norm >= 5
    with self.test_session():
    with self.test_session(use_gpu=True):
      x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      x1 = constant_op.constant([1.0, -2.0])
      # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
@ -263,7 +349,7 @@ class ClipTest(test.TestCase):

  def testClipByGlobalNormZero(self):
    # No norm clipping when norm = 0
    with self.test_session():
    with self.test_session(use_gpu=True):
      x0 = constant_op.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
      x1 = constant_op.constant([0.0, 0.0])
      # Norm = 0, no changes
@ -282,7 +368,7 @@ class ClipTest(test.TestCase):

  def testClipByAverageNormClipped(self):
    # Norm clipping when average clip_norm < 0.83333333
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      # Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
      np_ans = [[-2.88, 0.0, 0.0], [3.84, 0.0, 0.0]]
@ -294,7 +380,7 @@ class ClipTest(test.TestCase):

  def testClipByAverageNormClippedTensor(self):
    # Norm clipping when average clip_norm < 0.83333333
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      # Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
      np_ans = [[-2.88, 0.0, 0.0], [3.84, 0.0, 0.0]]
@ -306,7 +392,7 @@ class ClipTest(test.TestCase):

  def testClipByAverageNormNotClipped(self):
    # No norm clipping when average clip_norm >= 0.83333333
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
      # Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
      np_ans = [[-3.0, 0.0, 0.0], [4.0, 0.0, 0.0]]
@ -318,7 +404,7 @@ class ClipTest(test.TestCase):

  def testClipByAverageNormZero(self):
    # No norm clipping when average clip_norm = 0
    with self.test_session():
    with self.test_session(use_gpu=True):
      x = constant_op.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
      # Average norm = 0, no changes
      np_ans = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
@ -817,9 +817,6 @@ class PoolingTest(test.TestCase):
          cpu_val, gpu_val, half_rtol=0.01, half_atol=0.01)

  def testMaxPoolingWithArgmax(self):
    # MaxPoolWithArgMax is implemented only on CUDA.
    if not test.is_gpu_available(cuda_only=True):
      return
    tensor_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
    with self.test_session(use_gpu=True) as sess:
      t = constant_op.constant(tensor_input, shape=[1, 3, 3, 1])
@ -836,9 +833,6 @@ class PoolingTest(test.TestCase):
    self.assertAllEqual(argmax.ravel(), [0, 1, 3, 5])

  def testMaxPoolingGradWithArgmax(self):
    # MaxPoolWithArgMax is implemented only on CUDA.
    if not test.is_gpu_available(cuda_only=True):
      return
    orig_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
    tensor_input = [11.0, 12.0, 13.0, 14.0]
    tensor_argmax = list(np.array([0, 1, 3, 5], dtype=np.int64))
@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@ -70,6 +71,35 @@ def clip_by_value(t, clip_value_min, clip_value_max,
    _ = t.shape.merge_with(t_max.shape)

    return t_max
    # TODO(scottzhu): switch to use new implementation in 2 weeks.
    # return gen_math_ops.clip_by_value(
    #     t, clip_value_min, clip_value_max, name=name)


# TODO(scottzhu): switch to use new implementation in 2 weeks.
# @ops.RegisterGradient("ClipByValue")
def _clip_by_value_grad(op, grad):
  """Returns grad of clip_by_value."""
  x = op.inputs[0]
  y = op.inputs[1]
  z = op.inputs[2]
  gdtype = grad.dtype
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  sz = array_ops.shape(z)
  gradshape = array_ops.shape(grad)
  zeros = array_ops.zeros(gradshape, gdtype)
  xymask = math_ops.less(x, y)
  xzmask = math_ops.greater(x, z)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  rx, rz = gen_array_ops.broadcast_gradient_args(sx, sz)
  xgrad = array_ops.where(math_ops.logical_or(xymask, xzmask), zeros, grad)
  ygrad = array_ops.where(xymask, grad, zeros)
  zgrad = array_ops.where(xzmask, grad, zeros)
  gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
  gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
  gz = array_ops.reshape(math_ops.reduce_sum(zgrad, rz), sz)
  return (gx, gy, gz)


@tf_export("clip_by_norm")
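To make the masking logic in `_clip_by_value_grad` concrete, here is a NumPy re-statement of the rule for the element-wise case without broadcasting (the values are illustrative only):

```python
import numpy as np

x = np.array([0.5, 2.0, 5.0])      # input t
y, z = 1.0, 4.0                    # clip_value_min, clip_value_max
grad = np.ones_like(x)             # incoming gradient

below = x < y                      # positions clipped to the min
above = x > z                      # positions clipped to the max

gx = np.where(below | above, 0.0, grad)  # gradient reaches x only in range
gy = np.where(below, grad, 0.0)          # otherwise it goes to the min bound
gz = np.where(above, grad, 0.0)          # ... or to the max bound
print(gx, gy, gz)  # [0. 1. 0.] [1. 0. 0.] [0. 0. 1.]
```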
tensorflow/python/ops/hidden_ops.txt (new file, 395 lines)
@ -0,0 +1,395 @@
# array_ops
BatchToSpace
BroadcastArgs
BroadcastGradientArgs
ConcatOffset
Concat
ConcatV2
ConjugateTranspose
Const
DebugGradientIdentity
DebugGradientRefIdentity
EditDistance
ExpandDims
ListDiff
MirrorPad
MirrorPadGrad
OneHot
Pack
Pad
PadV2
ParallelConcat
Placeholder
RefIdentity
Reverse
Snapshot
SpaceToBatch
Split
SplitV
Squeeze
Slice
TileGrad  # Exported through array_grad instead of array_ops.
ZerosLike  # TODO(josh11b): Use this instead of the Python version.
Unique
UniqueV2
UniqueWithCounts
UniqueWithCountsV2
Unpack

# candidate_sampling_ops
AllCandidateSampler
ComputeAccidentalHits
FixedUnigramCandidateSampler
LearnedUnigramCandidateSampler
LogUniformCandidateSampler
ThreadUnsafeUnigramCandidateSampler
UniformCandidateSampler

# checkpoint_ops
GenerateVocabRemapping
LoadAndRemapMatrix


# control_flow_ops
Switch
Merge
RefMerge
Exit
RefExit

# ctc_ops
CTCLoss
CTCGreedyDecoder
CTCBeamSearchDecoder

# data_flow_ops
Barrier
BarrierClose
BarrierIncompleteSize
BarrierInsertMany
BarrierReadySize
BarrierTakeMany
DeleteSessionTensor
FakeQueue
FIFOQueue
FIFOQueueV2
GetSessionHandle
GetSessionHandleV2
GetSessionTensor
HashTable
HashTableV2
InitializeTable
InitializeTableV2
InitializeTableFromTextFile
InitializeTableFromTextFileV2
LookupTableExport
LookupTableExportV2
LookupTableFind
LookupTableFindV2
LookupTableImport
LookupTableImportV2
LookupTableInsert
LookupTableInsertV2
LookupTableSize
LookupTableSizeV2
MutableDenseHashTable
MutableDenseHashTableV2
MutableHashTable
MutableHashTableV2
MutableHashTableOfTensors
MutableHashTableOfTensorsV2
Mutex
MutexAcquire
MutexRelease
PaddingFIFOQueue
PaddingFIFOQueueV2
PriorityQueue
PriorityQueueV2
QueueClose
QueueCloseV2
QueueDequeue
QueueDequeueV2
QueueDequeueMany
QueueDequeueManyV2
QueueDequeueUpTo
QueueDequeueUpToV2
QueueEnqueue
QueueEnqueueV2
QueueEnqueueMany
QueueEnqueueManyV2
QueueSize
QueueSizeV2
RandomShuffleQueue
RandomShuffleQueueV2
Stack
StackClose
StackPop
StackPush
StackV2
StackCloseV2
StackPopV2
StackPushV2
TensorArray
TensorArrayClose
TensorArrayCloseV2
TensorArrayConcat
TensorArrayConcatV2
TensorArrayGather
TensorArrayGatherV2
TensorArrayGrad
TensorArrayGradV2
TensorArrayPack
TensorArrayPackV2
TensorArrayRead
TensorArrayReadV2
TensorArrayScatter
TensorArrayScatterV2
TensorArraySize
TensorArraySizeV2
TensorArraySplit
TensorArraySplitV2
TensorArrayUnpack
TensorArrayUnpackV2
TensorArrayV2
TensorArrayWrite
TensorArrayWriteV2
TensorArrayV3
TensorArrayCloseV3
TensorArrayConcatV3
TensorArrayGatherV3
TensorArrayGradV3
TensorArrayReadV3
TensorArrayPackV3
TensorArrayScatterV3
TensorArraySizeV3
TensorArraySplitV3
TensorArrayUnpackV3
TensorArrayWriteV3

# functional_ops
SymbolicGradient

# image_ops
AdjustContrastv2
NonMaxSuppression
NonMaxSuppressionV2
RandomCrop
ResizeBilinearGrad
ResizeBicubicGrad
ResizeNearestNeighborGrad
SampleDistortedBoundingBox
SampleDistortedBoundingBoxV2
ScaleImageGrad

# io_ops
FixedLengthRecordReader
IdentityReader
ReaderNumRecordsProduced
ReaderNumWorkUnitsCompleted
ReaderRead
ReaderReadUpTo
ReaderReset
ReaderRestoreState
ReaderSerializeState
ReaderWorkQueueLength
FixedLengthRecordReaderV2
IdentityReaderV2
ReaderNumRecordsProducedV2
ReaderNumWorkUnitsCompletedV2
ReaderReadV2
ReaderReadUpToV2
ReaderResetV2
ReaderRestoreStateV2
ReaderSerializeStateV2
ReaderWorkQueueLengthV2
Restore
RestoreSlice
Save
SaveSlices
ShardedFilename
ShardedFilespec
TextLineReader
TFRecordReader
WholeFileReader
TextLineReaderV2
TFRecordReaderV2
WholeFileReaderV2
LMDBReader
DecodeCSV

# linalg_ops
BatchCholesky
BatchCholeskyGrad
BatchMatrixDeterminant
BatchMatrixInverse
BatchMatrixSolve
BatchMatrixSolveLs
BatchMatrixTriangularSolve
BatchSelfAdjointEig
BatchSelfAdjointEigV2
BatchSvd
LogMatrixDeterminant
MatrixExponential
MatrixLogarithm
MatrixSolveLs
SelfAdjointEig
SelfAdjointEigV2
Svd

# logging_ops
Assert
AudioSummary
AudioSummaryV2
HistogramSummary
ImageSummary
MergeSummary
Print
ScalarSummary
TensorSummary
TensorSummaryV2

# math_ops
Abs
AccumulateNV2
AddN
AddV2
All
Any
BatchMatMul
BatchFFT
BatchFFT2D
BatchFFT3D
BatchIFFT
BatchIFFT2D
BatchIFFT3D
Bucketize
ClipByValue
Complex
ComplexAbs
Conj
FloorDiv
FloorMod
HistogramFixedWidth
Max
Mean
Min
Mul
Neg
Pow
Prod
Range
RealDiv
Select
SparseMatMul
Sub
Sum
MatMul
Sigmoid
Tanh
SigmoidGrad
TanhGrad
InvGrad
ReciprocalGrad
SqrtGrad
RsqrtGrad
TruncateDiv
TruncateMod

# nn_ops
AvgPoolGrad  # "*Grad" accessible through nn_grad instead of nn_ops.
AvgPool3DGrad
BatchNormWithGlobalNormalization
BatchNormWithGlobalNormalizationGrad
FusedBatchNorm
FusedBatchNormV2
SoftmaxCrossEntropyWithLogits
SparseSoftmaxCrossEntropyWithLogits
LRNGrad
MaxPoolGrad
MaxPoolGradWithArgmax
MaxPoolGradGrad
MaxPoolGradGradWithArgmax
MaxPool3DGrad
MaxPool3DGradGrad
ReluGrad
Relu6Grad
EluGrad
SeluGrad
SoftplusGrad
SoftsignGrad
TopK
TopKV2
BiasAdd
BiasAddV1
Relu6
AvgPool
MaxPool
MaxPoolV2
Softmax
LogSoftmax
FractionalAvgPoolGrad
|
||||
FractionalMaxPoolGrad
|
||||
InTopK
|
||||
InTopKV2
|
||||
|
||||
# parsing_ops
|
||||
ParseExample
|
||||
ParseSingleSequenceExample
|
||||
|
||||
# random_ops
|
||||
RandomGamma
|
||||
RandomPoisson
|
||||
RandomUniform
|
||||
RandomUniformInt
|
||||
RandomShuffle
|
||||
RandomStandardNormal
|
||||
ParameterizedTruncatedNormal
|
||||
TruncatedNormal
|
||||
|
||||
# script_ops
|
||||
PyFunc
|
||||
PyFuncStateless
|
||||
EagerPyFunc
|
||||
|
||||
# sdca_ops
|
||||
|
||||
# state_ops
|
||||
Variable
|
||||
VariableV2
|
||||
TemporaryVariable
|
||||
DestroyTemporaryVariable
|
||||
|
||||
# sparse_ops
|
||||
AddSparseToTensorsMap
|
||||
AddManySparseToTensorsMap
|
||||
TakeManySparseFromTensorsMap
|
||||
DeserializeManySparse
|
||||
DeserializeSparse
|
||||
SerializeManySparse
|
||||
SerializeSparse
|
||||
SparseAdd
|
||||
SparseAddGrad
|
||||
SparseConcat
|
||||
SparseCross
|
||||
SparseFillEmptyRows
|
||||
SparseFillEmptyRowsGrad
|
||||
SparseSplit
|
||||
SparseSelectLastK
|
||||
SparseReorder
|
||||
SparseReshape
|
||||
SparseToDense
|
||||
SparseTensorDenseAdd
|
||||
SparseTensorDenseMatMul
|
||||
|
||||
# string_ops
|
||||
StringSplit
|
||||
|
||||
# user_ops
|
||||
Fact
|
||||
|
||||
# training_ops
|
||||
# (None)
|
||||
|
||||
# word2vec deprecated ops
|
||||
NegTrain
|
||||
Skipgram
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import inspect as _inspect

from tensorflow.python.util import tf_decorator
@ -24,6 +25,15 @@ from tensorflow.python.util import tf_decorator
ArgSpec = _inspect.ArgSpec


if hasattr(_inspect, 'FullArgSpec'):
  FullArgSpec = _inspect.FullArgSpec  # pylint: disable=invalid-name
else:
  FullArgSpec = namedtuple('FullArgSpec', [
      'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults',
      'annotations'
  ])


def currentframe():
  """TFDecorator-aware replacement for inspect.currentframe."""
  return _inspect.stack()[1][0]
@ -55,13 +65,36 @@ def getfullargspec(obj):  # pylint: disable=redefined-builtin
    obj: A callable, possibly decorated.

  Returns:
    The `FullArgSpec` (`ArgSpec` in Python 2) that describes the signature of
    The `FullArgSpec` that describes the signature of
    the outermost decorator that changes the callable's signature. If the
    callable is not decorated, `inspect.getfullargspec()`
    (`inspect.getargspec()` in Python 2) will be called directly on the
    callable.
    callable is not decorated, `inspect.getfullargspec()` will be called
    directly on the callable.
  """
  spec_fn = getattr(_inspect, 'getfullargspec', getattr(_inspect, 'getargspec'))
  if hasattr(_inspect, 'getfullargspec'):
    spec_fn = _inspect.getfullargspec
  else:
    def spec_fn(target):
      """Spec function that adds default values, emulating FullArgSpec.

      It is used when getfullargspec is not available (e.g. in PY2).

      Args:
        target: the target object to inspect.
      Returns:
        The full argument specs with empty kwonlyargs, kwonlydefaults and
        annotations.
      """
      argspecs = _inspect.getargspec(target)
      fullargspecs = FullArgSpec(
          args=argspecs.args,
          varargs=argspecs.varargs,
          varkw=argspecs.keywords,
          defaults=argspecs.defaults,
          kwonlyargs=[],
          kwonlydefaults=None,
          annotations={})
      return fullargspecs

  decorators, target = tf_decorator.unwrap(obj)
  return next((d.decorator_argspec for d in decorators
               if d.decorator_argspec is not None), spec_fn(target))
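For reference, a minimal sketch of how the `FullArgSpec` shim above behaves; `example` is a hypothetical function, and the assertions assume a checkout that includes this change:

    from tensorflow.python.util import tf_inspect

    def example(a, b=1, *args, **kwargs):  # hypothetical sample function
      pass

    # Both PY2 and PY3 now yield a FullArgSpec-shaped result.
    spec = tf_inspect.getfullargspec(example)
    assert spec.args == ['a', 'b']
    assert spec.varargs == 'args' and spec.varkw == 'kwargs'
    assert spec.kwonlyargs == []  # always empty under the PY2 fallback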
@ -810,7 +810,33 @@ def tf_cc_test_mkl(srcs,
                   tags=[],
                   size="medium",
                   args=None):
  if_mkl(tf_cc_tests(srcs, deps, name, linkstatic=linkstatic, tags=tags, size=size, args=args, nocopts="-fno-exceptions"))
  for src in srcs:
    native.cc_test(
        name=src_to_test_name(src),
        srcs=if_mkl([src]) + tf_binary_additional_srcs(),
        copts=tf_copts(),
        linkopts=select({
            clean_dep("//tensorflow:android"): [
                "-pie",
            ],
            clean_dep("//tensorflow:windows"): [],
            clean_dep("//tensorflow:windows_msvc"): [],
            "//conditions:default": [
                "-lpthread",
                "-lm"
            ],
        }) + _rpath_linkopts(src_to_test_name(src)),
        deps=deps + if_mkl(
            [
                "//third_party/mkl:intel_binary_blob",
            ],
        ),
        linkstatic=linkstatic,
        tags=tags,
        size=size,
        args=args,
        nocopts="-fno-exceptions")


def tf_cc_tests_gpu(srcs,
                    deps,
@ -1029,16 +1055,12 @@ register_extension_info(
def tf_mkl_kernel_library(name,
                          prefix=None,
                          srcs=None,
                          gpu_srcs=None,
                          hdrs=None,
                          deps=None,
                          alwayslink=1,
                          copts=tf_copts(),
                          nocopts="-fno-exceptions",
                          **kwargs):
                          nocopts="-fno-exceptions"):
  """A rule to build MKL-based TensorFlow kernel libraries."""
  gpu_srcs = gpu_srcs  # unused argument
  kwargs = kwargs  # unused argument

  if not bool(srcs):
    srcs = []
@ -1051,16 +1073,15 @@ def tf_mkl_kernel_library(name,
  hdrs = hdrs + native.glob(
      [prefix + "*.h"])

  if_mkl(
      native.cc_library(
          name=name,
          srcs=srcs,
          hdrs=hdrs,
          deps=deps,
          alwayslink=alwayslink,
          copts=copts,
          nocopts=nocopts
      ))
  native.cc_library(
      name=name,
      srcs=if_mkl(srcs),
      hdrs=hdrs,
      deps=deps,
      alwayslink=alwayslink,
      copts=copts,
      nocopts=nocopts
  )

register_extension_info(
    extension_name = "tf_mkl_kernel_library",
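The rewrites above work around a subtlety: `if_mkl()` selects between values, so wrapping an entire rule invocation in it does not make the rule conditional, because the inner call runs before `if_mkl()` sees its result. A minimal Python analogy (the `if_mkl` and `declare_rule` definitions here are stand-ins for illustration, not the real Bazel semantics):

    def if_mkl(value, otherwise=None):
      mkl_enabled = False  # stand-in for the real --config=mkl setting
      return value if mkl_enabled else otherwise

    def declare_rule(**kwargs):
      print('rule declared with', kwargs)

    if_mkl(declare_rule(srcs=['a.cc']))      # old pattern: declares unconditionally
    declare_rule(srcs=if_mkl(['a.cc'], []))  # new pattern: gates only the attribute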
@ -160,7 +160,8 @@ def get_api_init_text():
  # we want to traverse over TensorFlow Python modules.
  for module in sys.modules.values():
    # Only look at tensorflow modules.
    if not module or 'tensorflow.' not in module.__name__:
    if (not module or not hasattr(module, '__name__') or
        'tensorflow.' not in module.__name__):
      continue
    # Do not generate __init__.py files for contrib modules for now.
    if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
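A minimal sketch of the hardened filter above: entries in `sys.modules` can be `None` (a PY2 import-cache artifact) and some injected module objects lack `__name__`, so both conditions must be checked before the name is inspected:

    import sys

    tf_modules = [
        m for m in sys.modules.values()
        if m is not None and hasattr(m, '__name__') and 'tensorflow.' in m.__name__
    ]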
@ -47,7 +47,7 @@ RUN pip --no-cache-dir install \
    http://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.0.0-cp27-none-linux_x86_64.whl
# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #

# RUN ln -s /usr/bin/python3 /usr/bin/python#
# RUN ln -s -f /usr/bin/python3 /usr/bin/python#

# Set up our notebook config.
COPY jupyter_notebook_config.py /root/.jupyter/

@ -38,6 +38,8 @@ RUN pip --no-cache-dir install \
    && \
    python -m ipykernel.kernelspec

# RUN ln -s -f /usr/bin/python3 /usr/bin/python#

# Set up our notebook config.
COPY jupyter_notebook_config.py /root/.jupyter/

@ -47,6 +47,8 @@ RUN pip --no-cache-dir install \
    && \
    python -m ipykernel.kernelspec

# RUN ln -s -f /usr/bin/python3 /usr/bin/python#

# Set up our notebook config.
COPY jupyter_notebook_config.py /root/.jupyter/

@ -54,7 +54,7 @@ RUN pip --no-cache-dir install \
    http://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.0.0-cp27-none-linux_x86_64.whl
# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #

# RUN ln -s /usr/bin/python3 /usr/bin/python#
# RUN ln -s -f /usr/bin/python3 /usr/bin/python#

# Set up our notebook config.
COPY jupyter_notebook_config.py /root/.jupyter/
@ -1207,7 +1207,7 @@
   "source": [
    "# Training computation: logits + cross-entropy loss.\n",
    "logits = model(train_data_node, True)\n",
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\n",
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(\n",
    "  labels=train_labels_node, logits=logits))\n",
    "\n",
    "# L2 regularization for the fully connected parameters.\n",
@ -2031,7 +2031,7 @@
    "views": {}
   },
   "kernelspec": {
    "display_name": "Python [default]",
    "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
@ -2049,5 +2049,5 @@
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
 "nbformat_minor": 1
}
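The notebook edit above tracks the TF 1.x deprecation of `tf.nn.softmax_cross_entropy_with_logits` in favor of the `_v2` variant, which also backpropagates into `labels`. A minimal sketch with hypothetical shapes (the notebook's `train_labels_node` and `logits` play the same roles); `tf.stop_gradient` restores the old behavior when the labels are constants:

    import tensorflow as tf

    logits = tf.random_normal([8, 10])
    labels = tf.one_hot(tf.random_uniform([8], maxval=10, dtype=tf.int32), 10)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.stop_gradient(labels), logits=logits))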
@ -284,7 +284,7 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
  if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \
     sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \
     sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \
     sed -i -e 's^# RUN ln -s /usr/bin/python3 /usr/bin/python#^RUN ln -s /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}"
     sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}"
  then
    echo "Modified Dockerfile for python version "\
"${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}"
@ -306,7 +306,7 @@ else
    sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \
    sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \
    sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \
    sed -i -e 's^# RUN ln -s /usr/bin/python3 /usr/bin/python#^RUN ln -s /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}"
    sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}"
  then
    echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}"
  else
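A small Python sketch of the placeholder convention these sed lines rely on: each Dockerfile carries the exact commented-out line, terminated with `#`, and the build script rewrites it verbatim. The trailing `#` makes the match unambiguous, so the substitution is a no-op on files without the marker (`dockerfile` below is sample data, not the real file):

    placeholder = '# RUN ln -s -f /usr/bin/python3 /usr/bin/python#'
    replacement = 'RUN ln -s -f /usr/bin/python3 /usr/bin/python'

    dockerfile = placeholder + '\n\n# Set up our notebook config.\n'
    print(dockerfile.replace(placeholder, replacement))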
@ -37,7 +37,7 @@ py_library(
    srcs = ["parser.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = ["@com_github_andreif_codegen"],
    deps = ["@astor_archive//:astor"],
)

py_test(
@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function

import os
import sys
import textwrap

import tensorflow as tf
@ -39,10 +38,6 @@ class Flags(object):
class BuildDocsTest(googletest.TestCase):

  def testBuildDocs(self):
    if sys.version_info >= (3, 0):
      print('Warning: Doc generation is not supported from python3.')
      return

    doc_generator = generate_lib.DocGenerator()

    doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)])
@ -21,7 +21,6 @@ from __future__ import print_function
import argparse
import fnmatch
import os
import sys

import six

@ -134,8 +133,12 @@ def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'):
    try:
      if not os.path.exists(directory):
        os.makedirs(directory)
      with open(path, 'w') as f:
        f.write(pretty_docs.build_md_page(page_info))
      # This function returns raw bytes in PY2 or unicode in PY3.
      text = pretty_docs.build_md_page(page_info)
      if six.PY3:
        text = text.encode('utf-8')
      with open(path, 'wb') as f:
        f.write(text)
    except OSError as e:
      print('Cannot write documentation for %s to %s: %s' % (full_name,
                                                             directory, e))
@ -437,19 +440,19 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
      full_out_path = os.path.join(output_dir, suffix)
      if not fnmatch.fnmatch(base_name, file_pattern):
        print('Copying un-matched file %s...' % suffix)
        open(full_out_path, 'w').write(open(full_in_path).read())
        open(full_out_path, 'wb').write(open(full_in_path, 'rb').read())
        continue
      if dirpath.endswith('/api_guides/python'):
        print('Processing Python guide %s...' % base_name)
        content = tag_updater.process(full_in_path)
      else:
        print('Processing doc %s...' % suffix)
        content = open(full_in_path).read()
        content = open(full_in_path, 'rb').read().decode('utf-8')

      content = reference_resolver.replace_references(content,
                                                      relative_path_to_root)
      with open(full_out_path, 'w') as f:
        f.write(content)
      with open(full_out_path, 'wb') as f:
        f.write(content.encode('utf-8'))

    print('Done.')

@ -458,8 +461,6 @@ class DocGenerator(object):
  """Main entry point for generating docs."""

  def __init__(self):
    if sys.version_info >= (3, 0):
      sys.exit('Doc generation is not supported from python3.')
    self.argument_parser = argparse.ArgumentParser()
    self._py_modules = None
    self._private_map = _get_default_private_map()

@ -52,9 +52,6 @@ class DummyVisitor(object):
class GenerateTest(googletest.TestCase):

  def test_write(self):
    if sys.version_info >= (3, 0):
      self.skipTest('Warning: Doc generation is not supported from python3.')

    module = sys.modules[__name__]

    index = {
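A minimal sketch of the byte-oriented writing pattern adopted above, so PY2 and PY3 handle non-ASCII documentation identically (`write_utf8` is a hypothetical helper for illustration, not part of the diff):

    def write_utf8(path, text):
      # `text` is unicode on PY3 and may already be raw bytes on PY2,
      # mirroring pretty_docs.build_md_page; normalize before writing.
      if not isinstance(text, bytes):
        text = text.encode('utf-8')
      with open(path, 'wb') as f:
        f.write(text)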
@ -26,7 +26,7 @@ import os
import re
import sys

import codegen
import astor
import six

from google.protobuf.message import Message as ProtoMessage
@ -621,20 +621,20 @@ def _parse_md_docstring(py_object, relative_path_to_root, reference_resolver):
def _get_arg_spec(func):
  """Extracts signature information from a function or functools.partial object.

  For functions, uses `tf_inspect.getargspec`. For `functools.partial` objects,
  corrects the signature of the underlying function to take into account the
  removed arguments.
  For functions, uses `tf_inspect.getfullargspec`. For `functools.partial`
  objects, corrects the signature of the underlying function to take into
  account the removed arguments.

  Args:
    func: A function whose signature to extract.

  Returns:
    An `ArgSpec` namedtuple `(args, varargs, keywords, defaults)`, as returned
    by `tf_inspect.getargspec`.
    A `FullArgSpec` namedtuple `(args, varargs, varkw, defaults, etc.)`,
    as returned by `tf_inspect.getfullargspec`.
  """
  # getargspec does not work for functools.partial objects directly.
  # getfullargspec does not work for functools.partial objects directly.
  if isinstance(func, functools.partial):
    argspec = tf_inspect.getargspec(func.func)
    argspec = tf_inspect.getfullargspec(func.func)
    # Remove the args from the original function that have been used up.
    first_default_arg = (
        len(argspec.args or []) - len(argspec.defaults or []))
@ -657,12 +657,16 @@ def _get_arg_spec(func):
        argspec_defaults.pop(i-first_default_arg)
      else:
        first_default_arg -= 1
    return tf_inspect.ArgSpec(args=argspec_args,
                              varargs=argspec.varargs,
                              keywords=argspec.keywords,
                              defaults=tuple(argspec_defaults))
    return tf_inspect.FullArgSpec(
        args=argspec_args,
        varargs=argspec.varargs,
        varkw=argspec.varkw,
        defaults=tuple(argspec_defaults),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
  else:  # Regular function or method, getargspec will work fine.
    return tf_inspect.getargspec(func)
    return tf_inspect.getfullargspec(func)


def _remove_first_line_indent(string):
@ -670,11 +674,14 @@ def _remove_first_line_indent(string):
  return '\n'.join([line[indent:] for line in string.split('\n')])


PAREN_NUMBER_RE = re.compile(r'^\(([0-9.e-]+)\)')


def _generate_signature(func, reverse_index):
  """Given a function, returns a list of strings representing its args.

  This function produces a list of strings representing the arguments to a
  python function. It uses tf_inspect.getargspec, which
  python function. It uses tf_inspect.getfullargspec, which
  does not generalize well to Python 3.x, which is more flexible in how *args
  and **kwargs are handled. This is not a problem in TF, since we have to remain
  compatible to Python 2.7 anyway.
@ -725,7 +732,11 @@ def _generate_signature(func, reverse_index):
      if id(default) in reverse_index:
        default_text = reverse_index[id(default)]
      elif ast_default is not None:
        default_text = codegen.to_source(ast_default)
        default_text = (
            astor.to_source(ast_default).rstrip('\n').replace('\t', '\\t')
            .replace('\n', '\\n').replace('"""', "'"))
        default_text = PAREN_NUMBER_RE.sub('\\1', default_text)

        if default_text != repr(default):
          # This may be an internal name. If so, handle the ones we know about.
          # TODO(wicke): This should be replaced with a lookup in the index.
@ -758,8 +769,8 @@ def _generate_signature(func, reverse_index):
  # Add *args and *kwargs.
  if argspec.varargs:
    args_list.append('*' + argspec.varargs)
  if argspec.keywords:
    args_list.append('**' + argspec.keywords)
  if argspec.varkw:
    args_list.append('**' + argspec.varkw)

  return args_list

@ -1136,9 +1147,11 @@ class _ClassPageInfo(object):

    for short_name in parser_config.tree[self.full_name]:
      # Remove builtin members that we never want to document.
      if short_name in ['__class__', '__base__', '__weakref__', '__doc__',
                        '__module__', '__dict__', '__abstractmethods__',
                        '__slots__', '__getnewargs__']:
      if short_name in [
          '__class__', '__base__', '__weakref__', '__doc__', '__module__',
          '__dict__', '__abstractmethods__', '__slots__', '__getnewargs__',
          '__str__', '__repr__', '__hash__'
      ]:
        continue

      child_name = '.'.join([self.full_name, short_name])
@ -1183,7 +1196,8 @@ class _ClassPageInfo(object):
      # obvious what they do, don't include them in the docs if there's no
      # docstring.
      if not child_doc.brief.strip() and short_name in [
          '__str__', '__repr__', '__hash__', '__del__', '__copy__']:
          '__del__', '__copy__'
      ]:
        print('Skipping %s, defined in %s, no docstring.' % (child_name,
                                                             defining_class))
        continue

@ -398,7 +398,6 @@ class ParserTest(googletest.TestCase):
    self.assertIn('<code>test_function', docs)

  def test_argspec_for_functools_partial(self):

    # pylint: disable=unused-argument
    def test_function_for_partial1(arg1, arg2, kwarg1=1, kwarg2=2):
      pass
@ -409,42 +408,95 @@ class ParserTest(googletest.TestCase):

    # pylint: disable=protected-access
    # Make sure everything works for regular functions.
    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None,
                                  None, (1, 2))
    expected = tf_inspect.FullArgSpec(
        args=['arg1', 'arg2', 'kwarg1', 'kwarg2'],
        varargs=None,
        varkw=None,
        defaults=(1, 2),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
    self.assertEqual(expected, parser._get_arg_spec(test_function_for_partial1))

    # Make sure doing nothing works.
    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None,
                                  None, (1, 2))
    expected = tf_inspect.FullArgSpec(
        args=['arg1', 'arg2', 'kwarg1', 'kwarg2'],
        varargs=None,
        varkw=None,
        defaults=(1, 2),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
    partial = functools.partial(test_function_for_partial1)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    # Make sure setting args from the front works.
    expected = tf_inspect.ArgSpec(['arg2', 'kwarg1', 'kwarg2'], None, None,
                                  (1, 2))
    expected = tf_inspect.FullArgSpec(
        args=['arg2', 'kwarg1', 'kwarg2'],
        varargs=None,
        varkw=None,
        defaults=(1, 2),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
    partial = functools.partial(test_function_for_partial1, 1)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    expected = tf_inspect.ArgSpec(['kwarg2',], None, None, (2,))
    expected = tf_inspect.FullArgSpec(
        args=['kwarg2'],
        varargs=None,
        varkw=None,
        defaults=(2,),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
    partial = functools.partial(test_function_for_partial1, 1, 2, 3)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    # Make sure setting kwargs works.
    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg2'], None, None, (2,))
    expected = tf_inspect.FullArgSpec(
        args=['arg1', 'arg2', 'kwarg2'],
        varargs=None,
        varkw=None,
        defaults=(2,),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
    partial = functools.partial(test_function_for_partial1, kwarg1=0)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1'], None, None, (1,))
    expected = tf_inspect.FullArgSpec(
        args=['arg1', 'arg2', 'kwarg1'],
        varargs=None,
        varkw=None,
        defaults=(1,),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
    partial = functools.partial(test_function_for_partial1, kwarg2=0)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    expected = tf_inspect.ArgSpec(['arg1'], None, None, ())
    expected = tf_inspect.FullArgSpec(
        args=['arg1'],
        varargs=None,
        varkw=None,
        defaults=(),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
    partial = functools.partial(test_function_for_partial1,
                                arg2=0, kwarg1=0, kwarg2=0)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    # Make sure *args, *kwargs is accounted for.
    expected = tf_inspect.ArgSpec([], 'my_args', 'my_kwargs', ())
    expected = tf_inspect.FullArgSpec(
        args=[],
        varargs='my_args',
        varkw='my_kwargs',
        defaults=(),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
    partial = functools.partial(test_function_for_partial2, 0, 1)
    self.assertEqual(expected, parser._get_arg_spec(partial))

@ -524,10 +576,6 @@ class TestParseFunctionDetails(googletest.TestCase):
class TestGenerateSignature(googletest.TestCase):

  def test_known_object(self):
    if sys.version_info >= (3, 0):
      print('Warning: Doc generation is not supported from python3.')
      return

    known_object = object()
    reverse_index = {id(known_object): 'location.of.object.in.api'}
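Conceptually, the `_get_arg_spec` behavior these tests pin down peels consumed parameters off the underlying function's `FullArgSpec`. A simplified sketch, ignoring the defaults bookkeeping (`f` is a hypothetical sample):

    import functools
    import inspect

    def f(arg1, arg2, kwarg1=1, kwarg2=2):
      pass

    partial = functools.partial(f, 1, kwarg1=0)
    spec = inspect.getfullargspec(partial.func)
    remaining = spec.args[len(partial.args):]                        # drops 'arg1'
    remaining = [a for a in remaining if a not in partial.keywords]  # drops 'kwarg1'
    assert remaining == ['arg2', 'kwarg2']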
@ -101,7 +101,7 @@ def _build_class_page(page_info):

  link_template = '[`{short_name}`]({url})'
  parts.append(', '.join(
      link_template.format(**base.__dict__) for base in page_info.bases))
      link_template.format(**base._asdict()) for base in page_info.bases))

  parts.append('\n\n')

@ -159,7 +159,7 @@ def _build_class_page(page_info):
    h3 = ('<h3 id="{short_name}">'
          '<code>{short_name}</code>'
          '</h3>\n\n')
    parts.append(h3.format(**method_info.__dict__))
    parts.append(h3.format(**method_info._asdict()))

    if method_info.signature is not None:
      parts.append(_build_signature(method_info, use_full_name=False))
@ -217,7 +217,7 @@ def _build_module_page(page_info):
    template = '[`{short_name}`]({url}) module'

    for item in page_info.modules:
      parts.append(template.format(**item.__dict__))
      parts.append(template.format(**item._asdict()))

      if item.doc.brief:
        parts.append(': ' + item.doc.brief)
@ -229,7 +229,7 @@ def _build_module_page(page_info):
    template = '[`class {short_name}`]({url})'

    for item in page_info.classes:
      parts.append(template.format(**item.__dict__))
      parts.append(template.format(**item._asdict()))

      if item.doc.brief:
        parts.append(': ' + item.doc.brief)
@ -241,7 +241,7 @@ def _build_module_page(page_info):
    template = '[`{short_name}(...)`]({url})'

    for item in page_info.functions:
      parts.append(template.format(**item.__dict__))
      parts.append(template.format(**item._asdict()))

      if item.doc.brief:
        parts.append(': ' + item.doc.brief)
@ -254,7 +254,7 @@ def _build_module_page(page_info):
    parts.append('## Other Members\n\n')

    for item in page_info.other_members:
      parts.append('`{short_name}`\n\n'.format(**item.__dict__))
      parts.append('`{short_name}`\n\n'.format(**item._asdict()))

  return ''.join(parts)
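Why `_asdict()` replaces `__dict__` in these templates: the page-info items are namedtuples, which define `__slots__ = ()` and so carry no per-instance `__dict__` on Python 3, while `_asdict()` reliably returns a mapping usable with `str.format`. A quick sketch (`Item` here is a stand-in for the real page-info types):

    import collections

    Item = collections.namedtuple('Item', ['short_name', 'url'])
    item = Item(short_name='foo', url='../tf/foo.md')
    print('[`{short_name}`]({url})'.format(**item._asdict()))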
@ -44,7 +44,7 @@ class PyGuideParser(object):

  def process(self, full_path):
    """Read and process the file at `full_path`."""
    md_string = open(full_path).read()
    md_string = open(full_path, 'rb').read().decode('utf-8')
    self._lines = md_string.split('\n')
    seen = set()
@ -315,18 +315,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
      strip_prefix = "backports.weakref-1.0rc1/src",
      build_file = clean_dep("//third_party:backports_weakref.BUILD"),
  )

  tf_http_archive(
      name = "com_github_andreif_codegen",
      urls = [
          "https://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz",
          "https://github.com/andreif/codegen/archive/1.0.tar.gz",
      ],
      sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee",
      strip_prefix = "codegen-1.0",
      build_file = clean_dep("//third_party:codegen.BUILD"),
  )

  filegroup_external(
      name = "org_python_license",
      licenses = ["notice"],  # Python 2.0
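With the `com_github_andreif_codegen` repository dropped above, the docs parser renders default values through `astor` instead. A quick sketch of the call the new code relies on (output formatting varies by astor version, which is why the parser rstrips, escapes, and applies PAREN_NUMBER_RE):

    import ast
    import astor

    node = ast.parse('x = 1e-4')
    print(astor.to_source(node))  # source text, typically with a trailing newline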