Merge pull request #3124 from rmlarsen/branch_126308395

Branch 126308395
This commit is contained in:
Vijay Vasudevan 2016-06-30 11:31:49 -07:00 committed by GitHub
commit ac90ecb08d
55 changed files with 1976 additions and 749 deletions

View File

@ -77,17 +77,6 @@ py_test(
],
)
cuda_py_test(
name = "embedding_ops_test",
size = "small",
srcs = ["python/ops/embedding_ops_test.py"],
additional_deps = [
":framework_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "sampling_ops_test",
size = "small",

View File

@ -213,6 +213,19 @@ py_test(
],
)
py_test(
name = "embedding_ops_test",
size = "small",
srcs = ["python/layers/embedding_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":layers_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.layers.python.layers.embedding_ops import *
from tensorflow.contrib.layers.python.layers.feature_column import *
from tensorflow.contrib.layers.python.layers.feature_column_ops import *
from tensorflow.contrib.layers.python.layers.initializers import *

View File

@ -0,0 +1,191 @@
# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
__all__ = ["safe_embedding_lookup_sparse", "hashed_embedding_lookup",
"hashed_embedding_lookup_sparse"]
# TODO(chapelle): move the safe_embedding_lookup_sparse code here (b/29826543)
safe_embedding_lookup_sparse = contrib_embedding_ops.safe_embedding_lookup_sparse # pylint: disable=line-too-long
def hashed_embedding_lookup(params, values, dimension, name=None):
"""Looks up embeddings using parameter hashing for each value in `values`.
The i-th embedding component of a value v in `values` is found by retrieving
the weight whose index is a fingerprint of the pair (v,i).
The concept is explored as "feature hashing" for model compression in this
paper: http://arxiv.org/pdf/1504.04788.pdf
Feature hashing has the pleasant effect of allowing us to compute an embedding
without needing a pre-determined vocabulary, relieving some amount of process
complexity. It also allows for us to maintain embeddings for possibly
trillions of features with a fixed amount of memory.
Note that this is superior to out-of-vocabulary shared "hash buckets" in that
the embedding is extremely likely to be unique for each token as opposed to
being shared across probably-colliding tokens. The price is that we must
compute a hash once for each scalar in the token's embedding as opposed to
once per token.
If `params` is a list, it represents a partition of the embedding parameters.
Each tensor in the list should have the same length, except for the first ones
which may have an additional element. For instance 10 parameters can be
partitioned in 4 tensors with length `[3, 3, 2, 2]`.
Args:
params: A `Tensor` or `list` of `Tensors`.
Each tensor must be of rank 1 with fully-defined shape.
values: `Tensor` of values to be embedded.
dimension: Embedding dimension
name: An optional name for this op.
Returns:
A tensor with shape [d0, ..., dn, dimension]
with shape(values) = [d0, ..., dn]
Raises:
ValueError: if dimension is not positive or the partition size is invalid.
"""
if not isinstance(params, list):
params = [params]
with ops.op_scope(params + [dimension, values], name,
"hashed_embedding_lookup"):
if dimension <= 0:
raise ValueError("Dimension should be >0 not %d" % dimension)
num_partitions = len(params)
partition_sizes = []
for p in range(num_partitions):
shape = params[p].get_shape()
shape.assert_has_rank(1)
shape.assert_is_fully_defined()
partition_sizes.append(shape[0].value)
num_params = sum(partition_sizes) # Total number of parameters.
# Assert the size of each partition.
for p in range(num_partitions):
expected_size = (num_params - p - 1) // num_partitions + 1
if partition_sizes[p] != expected_size:
raise ValueError("Tensor %d in params has size %d, expected %d." %
(p, partition_sizes[p], expected_size))
# Flatten the values
values_shape = array_ops.shape(values)
values = array_ops.reshape(values, [-1, 1])
# With two values v1 and v2 and 3 dimensions, we will cross
# [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]].
tensors_to_cross = [array_ops.tile(array_ops.expand_dims(
math_ops.range(0, dimension), 0), array_ops.shape(values)), values]
ids = sparse_feature_cross_op.sparse_feature_cross(
tensors_to_cross, hashed_output=True, num_buckets=num_params)
ids = sparse_ops.sparse_tensor_to_dense(ids)
# No need to validate the indices since we have checked the params
# dimensions and we know the largest id.
result = embedding_ops.embedding_lookup(
params, ids, partition_strategy="div", validate_indices=False)
return array_ops.reshape(result, array_ops.concat(
0, [values_shape, [dimension]]))
def hashed_embedding_lookup_sparse(params,
sparse_values,
dimension,
combiner="mean",
default_value=None,
name=None):
"""Looks up embeddings of a sparse feature using parameter hashing.
See `tf.contrib.layers.hashed_embedding_lookup` for embedding with hashing.
Args:
params: A `Tensor` or `list` of `Tensors`.
Each tensor must be of rank 1 with fully-defined shape.
sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
Some rows may be empty.
dimension: Embedding dimension
combiner: A string specifying how to combine embedding results for each
entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
the default.
default_value: The value to use for an entry with no features.
name: An optional name for this op.
Returns:
Dense tensor with shape [N, dimension] with N the number of rows in
sparse_values.
Raises:
TypeError: If sparse_values is not a SparseTensor.
ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
"""
if not isinstance(params, list):
params = [params]
if not isinstance(sparse_values, ops.SparseTensor):
raise TypeError("sparse_values must be SparseTensor")
with ops.op_scope(params + [sparse_values], name,
"hashed_sparse_embedding_lookup") as scope:
# Fill in the empty rows.
if default_value is None:
# Random default values to reduce the risk of collision.
if sparse_values.dtype == dtypes.string:
default_value = "6ZxWzWOHxZ"
else:
default_value = 1288896567
sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
sparse_values, default_value)
segment_ids = sparse_values.indices[:, 0]
if segment_ids.dtype != dtypes.int32:
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
values = sparse_values.values
values, idx = array_ops.unique(values)
embeddings = hashed_embedding_lookup(params, values, dimension)
if combiner == "sum":
embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
name=scope)
elif combiner == "mean":
embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
name=scope)
elif combiner == "sqrtn":
embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids,
name=scope)
else:
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
return embeddings

View File

@ -100,7 +100,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, sparse_weights = self._ids_and_weights_2d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
tf.contrib.layers.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights).eval())
self.assertAllClose(embedding_lookup_result, [
@ -114,7 +114,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, sparse_weights = self._ids_and_weights_2d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
tf.contrib.layers.safe_embedding_lookup_sparse(
embedding_weights,
sparse_ids,
sparse_weights,
@ -132,9 +132,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, _ = self._ids_and_weights_2d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
self.assertAllClose(embedding_lookup_result, [
(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
@ -148,9 +148,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, _ = self._ids_and_weights_2d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
embedding_weights = list(itertools.chain(*embedding_weights))
self.assertAllClose(embedding_lookup_result, [
@ -166,13 +166,13 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
embedding_weights[1] = embedding_weights[1].astype(np.float64)
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
tf.contrib.layers.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
tf.constant(w, dtype=tf.float64) for w in embedding_weights
]
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
tf.contrib.layers.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
@ -181,7 +181,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, sparse_weights = self._ids_and_weights_3d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
tf.contrib.layers.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights).eval())
self.assertAllClose(embedding_lookup_result, [
@ -195,7 +195,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, sparse_weights = self._ids_and_weights_3d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
tf.contrib.layers.safe_embedding_lookup_sparse(
embedding_weights,
sparse_ids,
sparse_weights,
@ -214,9 +214,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, _ = self._ids_and_weights_3d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
self.assertAllClose(embedding_lookup_result, [
[(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
@ -231,9 +231,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, _ = self._ids_and_weights_3d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
embedding_weights = list(itertools.chain(*embedding_weights))
self.assertAllClose(embedding_lookup_result, [
@ -251,15 +251,110 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
embedding_weights[1] = embedding_weights[1].astype(np.float64)
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
tf.contrib.layers.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
tf.constant(w, dtype=tf.float64) for w in embedding_weights
]
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
tf.contrib.layers.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
class HashedEmbeddingLookupTest(tf.test.TestCase):
def setUp(self):
tf.set_random_seed(1)
def _random_weights(self, size=50, num_shards=1):
assert size > 0
assert num_shards > 0
assert num_shards <= size
embedding_weights = tf.create_partitioned_variables(
shape=[size],
slicing=[num_shards],
initializer=tf.truncated_normal_initializer(mean=0.0,
stddev=1.0,
dtype=tf.float32))
for w in embedding_weights:
w.initializer.run()
return embedding_weights
def test_hashed_embedding_consistency(self):
with self.test_session():
embedding_weights = self._random_weights()
values = tf.constant(["foo", "foo"])
embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
embedding_weights, values, dimension=10).eval()
self.assertAllEqual(embedding_lookup_result.shape, [2, 10])
self.assertAllEqual(embedding_lookup_result[0],
embedding_lookup_result[1])
def test_hashed_embedding_multiple_partition(self):
with self.test_session():
embedding_weights = self._random_weights(num_shards=7)
values = tf.constant([4, 4, 5])
embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
embedding_weights, values, dimension=5).eval()
self.assertAllEqual(embedding_lookup_result.shape, [3, 5])
self.assertAllEqual(embedding_lookup_result[0],
embedding_lookup_result[1])
# Different embedding expected for different value.
embedding_diff = np.min((embedding_lookup_result[2] -
embedding_lookup_result[0]) ** 2)
self.assertGreater(embedding_diff, 0)
def test_hashed_embedding_coverage(self):
with self.test_session():
size = 8
embedding_weights = self._random_weights(size=size, num_shards=3)
values = tf.constant(["foo"])
# Large embedding dimension to cover the full range of weights.
embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
embedding_weights, values, dimension=100).eval()
self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
def test_hashed_embedding_multi_dimension(self):
with self.test_session():
embedding_weights = self._random_weights()
values = tf.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]])
embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
embedding_weights, values, dimension=10).eval()
self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 10])
self.assertAllEqual(embedding_lookup_result[0][0],
embedding_lookup_result[1][2])
def test_hashed_embedding_lookup_sparse(self):
with self.test_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_tensor = tf.SparseTensor(values=["foo", "bar", "foo", "bar"],
indices=[[0, 0], [1, 0], [1, 1], [3, 0]],
shape=[5, 2])
embedding_lookup_result = (
tf.contrib.layers.hashed_embedding_lookup_sparse(
embedding_weights, sparse_tensor, dimension=5, combiner="mean")
.eval())
self.assertAllEqual(embedding_lookup_result.shape, [5, 5])
# Same non-zero embedding for the empty rows filled with a default value.
self.assertAllEqual(embedding_lookup_result[2],
embedding_lookup_result[4])
embedding_norm = np.sum(embedding_lookup_result[2] ** 2)
self.assertGreater(embedding_norm, 0)
self.assertAllEqual(embedding_lookup_result[1],
0.5 * (embedding_lookup_result[0] +
embedding_lookup_result[3]))
if __name__ == "__main__":
tf.test.main()

View File

@ -75,7 +75,7 @@ import abc
import collections
import math
from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops
from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops
@ -517,7 +517,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
# This is effectively the same format as str(self), except with our special
# treatment.
return "_EmbeddingColumn(%s)" % ", ".join(fields_values)
return "%s(%s)" % (type(self).__name__, ", ".join(fields_values))
def insert_transformed_feature(self, columns_to_tensors):
self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
@ -576,6 +576,105 @@ def embedding_column(sparse_id_column,
return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer)
class _HashedEmbeddingColumn(collections.namedtuple(
"_HashedEmbeddingColumn", ["column_name", "size", "dimension", "combiner",
"initializer"]), _EmbeddingColumn):
"""See `hashed_embedding_column`."""
def __new__(cls,
column_name,
size,
dimension,
combiner="mean",
initializer=None):
if initializer is not None and not callable(initializer):
raise ValueError("initializer must be callable if specified.")
if initializer is None:
stddev = 0.1
# TODO(b/25671353): Better initial value?
initializer = init_ops.truncated_normal_initializer(mean=0.0,
stddev=stddev)
return super(_HashedEmbeddingColumn, cls).__new__(cls, column_name, size,
dimension, combiner,
initializer)
@property
def name(self):
return self.column_name + "_embedding"
@property
def config(self):
return {self.column_name: parsing_ops.VarLenFeature(dtypes.string)}
def insert_transformed_feature(self, columns_to_tensors):
columns_to_tensors[self] = columns_to_tensors[self.column_name]
def to_dnn_input_layer(self,
input_tensor,
weight_collections=None,
trainable=True):
# Same heuristic for the number of shards as _max_size_embedding_partitioner
max_shard_bytes = (64 << 20) - 1
shards = self.size * 4.0 / max_shard_bytes
shards = max(1, int(math.ceil(shards)))
embeddings = partitioned_variables.create_partitioned_variables(
shape=[self.size],
slicing=[shards],
initializer=self.initializer,
dtype=dtypes.float32,
collections=_add_variable_collection(weight_collections),
name=self.name + "_weights",
reuse=False,
trainable=trainable)
return embedding_ops.hashed_embedding_lookup_sparse(
embeddings, input_tensor, self.dimension, name=self.name + "_lookup")
def hashed_embedding_column(column_name,
size,
dimension,
combiner="mean",
initializer=None):
"""Creates an embedding column of a sparse feature using parameter hashing.
The i-th embedding component of a value v is found by retrieving an
embedding weight whose index is a fingerprint of the pair (v,i).
Args:
column_name: A string defining sparse column name.
size: An integer specifying the number of parameters in the embedding layer.
dimension: An integer specifying dimension of the embedding.
combiner: A string specifying how to reduce if there are multiple entries
in a single row. Currently "mean", "sqrtn" and "sum" are supported. Each
of this can be thought as example level normalizations on the column:
* "sum": do not normalize features in the column
* "mean": do l1 normalization on features in the column
* "sqrtn": do l2 normalization on features in the column
For more information: `tf.embedding_lookup_sparse`.
initializer: A variable initializer function to be used in embedding
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean 0 and standard deviation 0.1.
Returns:
A _HashedEmbeddingColumn.
Raises:
ValueError: if dimension or size is not a positive integer; or if combiner
is not supported.
"""
if (dimension < 1) or (size < 1):
raise ValueError("Dimension and size must be greater than 0.")
if combiner not in ("mean", "sqrtn", "sum"):
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
return _HashedEmbeddingColumn(column_name, size, dimension, combiner,
initializer)
class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
"_RealValuedColumn",
["column_name", "dimension", "default_value", "dtype"])):
@ -1237,7 +1336,7 @@ def _create_embedding_lookup(input_tensor, vocab_size, dimension,
reuse=False,
trainable=trainable)
return contrib_embedding_ops.safe_embedding_lookup_sparse(
return embedding_ops.safe_embedding_lookup_sparse(
embeddings,
input_tensor,
default_id=0,

View File

@ -282,6 +282,27 @@ class InputLayerTest(tf.test.TestCase):
tf.initialize_all_variables().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testHashedEmbeddingColumn(self):
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo", "omar"],
indices=[[0, 0], [1, 0], [1, 1], [2, 0]],
shape=[3, 2])
features = {"wire": wire_tensor}
# Big enough hash space so that hopefully there is no collision
embedded_sparse = tf.contrib.layers.hashed_embedding_column("wire", 1000, 3)
output = tf.contrib.layers.input_from_feature_columns(
features, [embedded_sparse], weight_collections=["my_collection"])
weights = tf.get_collection("my_collection")
grad = tf.gradients(output, weights)
with self.test_session():
tf.initialize_all_variables().run()
gradient_values = []
# Collect the gradient from the different partitions (one in this test)
for p in range(len(grad)):
gradient_values.extend(grad[p].values.eval())
gradient_values.sort()
self.assertAllEqual(gradient_values, [0.5]*6 + [2]*3)
def testEmbeddingColumnWithInitializer(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],

View File

@ -42,7 +42,8 @@ class _BaseEstimator(object):
Args:
deep: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
Returns:
@ -209,4 +210,3 @@ else:
log_loss = None
mean_squared_error = _mean_squared_error
train_test_split = _train_test_split

View File

@ -89,10 +89,14 @@ class TensorFlowEstimator(estimator.Estimator):
Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
clip_gradients: Clip norm of the gradients to this value to stop
gradient explosion.
class_weight: None or list of n_classes floats. Weight associated with
@ -103,9 +107,10 @@ class TensorFlowEstimator(estimator.Estimator):
config: RunConfig object that controls the configurations of the
session, e.g. num_cores, gpu_memory_fraction, etc.
verbose: Controls the verbosity, possible values:
0: the algorithm and debug information is muted.
1: trainer prints the progress.
2: log device placement is printed.
* 0: the algorithm and debug information is muted.
* 1: trainer prints the progress.
* 2: log device placement is printed.
"""
self.class_weight = class_weight
self.learning_rate = learning_rate

View File

@ -69,15 +69,16 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
"""
def __init__(self,
@ -219,15 +220,16 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
"""
def __init__(self,

View File

@ -42,9 +42,23 @@ class DNNClassifierTest(tf.test.TestCase):
classifier.fit(input_fn=_iris_input_fn, steps=1000)
classifier.evaluate(input_fn=_iris_input_fn, steps=100)
self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
# TODO(ispir): Enable accuracy check after resolving the randomness issue.
# self.assertGreater(scores['accuracy/mean'], 0.6)
def testDisableCenteredBias(self):
"""Tests that we can disable centered bias."""
cont_features = [
tf.contrib.layers.real_valued_column('feature', dimension=4)]
classifier = tf.contrib.learn.DNNClassifier(n_classes=3,
feature_columns=cont_features,
hidden_units=[3, 3],
enable_centered_bias=False)
classifier.fit(input_fn=_iris_input_fn, steps=1000)
self.assertFalse('centered_bias_weight' in classifier.get_variable_names())
class DNNRegressorTest(tf.test.TestCase):

View File

@ -102,7 +102,8 @@ def _get_arguments(func):
class BaseEstimator(sklearn.BaseEstimator):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
Concrete implementation of this class should provide following functions:
Concrete implementation of this class should provide the following functions:
* _get_train_ops
* _get_eval_ops
* _get_predict_ops
@ -165,12 +166,15 @@ class BaseEstimator(sklearn.BaseEstimator):
input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
steps: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
batch_size: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
monitors: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
max_steps: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -180,8 +184,6 @@ class BaseEstimator(sklearn.BaseEstimator):
Raises:
ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
ValueError: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
ValueError: If both `steps` and `max_steps` are not `None`.
"""
if (steps is not None) and (max_steps is not None):
@ -273,17 +275,19 @@ class BaseEstimator(sklearn.BaseEstimator):
provided.
steps: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
metrics: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
metrics: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
name: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
name: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
Returns:
Returns `dict` with evaluation results.
@ -656,20 +660,23 @@ class Estimator(BaseEstimator):
model_fn: Model function, takes features and targets tensors or dicts of
tensors and returns predictions and loss tensors.
Supports next three signatures for the function:
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) ->
(predictions, loss, train_op)`
Where:
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
`dict` of `Tensor`s (for multi-head model).
* `mode` represents if this training, evaluation or
prediction. See `ModeKeys` for example keys.
* `params` is a `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tunning.
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) -> (predictions, loss, train_op)`
Where
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
`dict` of `Tensor`s (for multi-head model).
* `mode` represents if this training, evaluation or
prediction. See `ModeKeys` for example keys.
* `params` is a `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tunning.
model_dir: Directory to save model parameters, graph and etc.
config: Configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.

View File

@ -78,15 +78,16 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
"""
def __init__(self,
@ -229,15 +230,16 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
Input of `fit` and `evaluate` should have following features,
otherwise there will be a KeyError:
if `weight_column_name` is not `None`:
key=weight_column_name, value=a `Tensor`
for column in `feature_columns`:
- if isinstance(column, `SparseColumn`):
key=column.name, value=a `SparseTensor`
- if isinstance(column, `RealValuedColumn`):
key=column.name, value=a `Tensor`
- if `feauture_columns` is `None`:
input must contains only real valued `Tensor`.
* if `weight_column_name` is not `None`:
key=weight_column_name, value=a `Tensor`
* for column in `feature_columns`:
- if isinstance(column, `SparseColumn`):
key=column.name, value=a `SparseTensor`
- if isinstance(column, `RealValuedColumn`):
key=column.name, value=a `Tensor`
- if `feature_columns` is `None`:
input must contains only real valued `Tensor`.
"""
def __init__(self,

View File

@ -47,6 +47,26 @@ class LinearClassifierTest(tf.test.TestCase):
loss2 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
self.assertLess(loss2, loss1)
self.assertLess(loss2, 0.01)
self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
def testDisableCenteredBias(self):
"""Tests that we can disable centered bias."""
def input_fn():
return {
'age': tf.constant([1]),
'language': tf.SparseTensor(values=['english'],
indices=[[0, 0]],
shape=[1, 1])
}, tf.constant([[1]])
language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
age = tf.contrib.layers.real_valued_column('age')
classifier = tf.contrib.learn.LinearClassifier(
feature_columns=[age, language], enable_centered_bias=False)
classifier.fit(input_fn=input_fn, steps=100)
self.assertFalse('centered_bias_weight' in classifier.get_variable_names())
def testTrainOptimizerWithL1Reg(self):
"""Tests l1 regularized model has higher loss."""

View File

@ -81,10 +81,14 @@ class TensorFlowRNNClassifier(TensorFlowEstimator, _sklearn.ClassifierMixin):
used. Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
class_weight: None or list of n_classes floats. Weight associated with
classes for loss computation. If not given, all classes are
supposed to have weight one.
@ -186,18 +190,23 @@ class TensorFlowRNNRegressor(TensorFlowEstimator, _sklearn.RegressorMixin):
used. Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
continue_training: when continue_training is True, once initialized
model will be continuely trained on every call of fit.
config: RunConfig object that controls the configurations of the
session, e.g. num_cores, gpu_memory_fraction, etc.
verbose: Controls the verbosity, possible values:
0: the algorithm and debug information is muted.
1: trainer prints the progress.
2: log device placement is printed.
* 0: the algorithm and debug information is muted.
* 1: trainer prints the progress.
* 2: log device placement is printed.
"""
self.rnn_size = rnn_size
self.cell_type = cell_type

View File

@ -216,7 +216,7 @@ def read_keyed_batch_features(
file_pattern, batch_size, features, reader,
randomize_input=True, num_epochs=None,
queue_capacity=10000, reader_num_threads=1,
parser_num_threads=1, read_batch_size=1, name=None):
parser_num_threads=1, name=None):
"""Adds operations to read, queue, batch and parse `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
@ -245,8 +245,6 @@ def read_keyed_batch_features(
queue_capacity: Capacity for input queue.
reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once
name: Name of resulting op.
Returns:
@ -260,7 +258,7 @@ def read_keyed_batch_features(
keys, examples = read_keyed_batch_examples(
file_pattern, batch_size, reader, randomize_input=randomize_input,
num_epochs=num_epochs, queue_capacity=queue_capacity,
num_threads=reader_num_threads, read_batch_size=read_batch_size,
num_threads=reader_num_threads, read_batch_size=batch_size,
name=scope)
if parser_num_threads == 1:
@ -286,8 +284,7 @@ def read_keyed_batch_features(
def read_batch_features(file_pattern, batch_size, features, reader,
randomize_input=True, num_epochs=None,
queue_capacity=10000, reader_num_threads=1,
parser_num_threads=1, read_batch_size=1,
name=None):
parser_num_threads=1, name=None):
"""Adds operations to read, queue, batch and parse `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
@ -316,7 +313,6 @@ def read_batch_features(file_pattern, batch_size, features, reader,
queue_capacity: Capacity for input queue.
reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once
name: Name of resulting op.
@ -331,8 +327,7 @@ def read_batch_features(file_pattern, batch_size, features, reader,
file_pattern, batch_size, features, reader,
randomize_input=randomize_input, num_epochs=num_epochs,
queue_capacity=queue_capacity, reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads, read_batch_size=read_batch_size,
name=name)
parser_num_threads=parser_num_threads, name=name)
return features

View File

@ -58,6 +58,7 @@ tf_cc_test(
linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
deps = [
"//tensorflow/contrib/linear_optimizer:sdca_op_kernels",
"//tensorflow/contrib/linear_optimizer:sdca_ops_op_lib",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",

View File

@ -51,12 +51,14 @@ std::vector<Node*> VarVector(Graph* const g, const int nodes,
return result;
}
Node* Zeros(Graph* const g, const int n) {
Tensor data(DT_FLOAT, TensorShape({n}));
Node* Zeros(Graph* const g, const TensorShape& shape) {
Tensor data(DT_FLOAT, shape);
data.flat<float>().setZero();
return test::graph::Constant(g, data);
}
Node* Zeros(Graph* const g, const int n) { return Zeros(g, TensorShape({n})); }
Node* Ones(Graph* const g, const int n) {
Tensor data(DT_FLOAT, TensorShape({n}));
test::FillFn<float>(&data, [](const int i) { return 1.0f; });
@ -166,28 +168,25 @@ void GetGraphs(const int32 num_examples, const int32 sparse_feature_groups,
Node* const weights = Ones(g, num_examples);
Node* const labels = RandomZeroOrOne(g, num_examples);
Node* const ids = StringIota(g, num_examples);
Node* const example_state_data = Zeros(g, TensorShape({num_examples, 4}));
Node* sdca = nullptr;
TF_CHECK_OK(
NodeBuilder(g->NewName("sdca"), "SdcaSolver")
.Attr("loss_type", "logistic_loss")
.Attr("num_sparse_features", sparse_feature_groups)
.Attr("num_dense_features", dense_feature_groups)
.Attr("l1", 0.0)
.Attr("l2", 1.0)
.Attr("num_inner_iterations", 2)
.Attr("container", strings::StrCat(strings::Hex(random::New64())))
.Attr("solver_uuid", strings::StrCat(strings::Hex(random::New64())))
.Input(sparse_indices)
.Input(sparse_values)
.Input(dense_features)
.Input(weights)
.Input(labels)
.Input(ids)
.Input(sparse_weights)
.Input(dense_weights)
.Finalize(g, &sdca));
TF_CHECK_OK(NodeBuilder(g->NewName("sdca"), "SdcaSolver")
.Attr("loss_type", "logistic_loss")
.Attr("num_sparse_features", sparse_feature_groups)
.Attr("num_dense_features", dense_feature_groups)
.Attr("l1", 0.0)
.Attr("l2", 1.0)
.Attr("num_inner_iterations", 2)
.Input(sparse_indices)
.Input(sparse_values)
.Input(dense_features)
.Input(weights)
.Input(labels)
.Input(sparse_weights)
.Input(dense_weights)
.Input(example_state_data)
.Finalize(g, &sdca));
*train_g = g;
}
@ -202,14 +201,22 @@ void BM_SDCA(const int iters, const int num_examples) {
&train);
testing::StartTiming();
test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
// TODO(sibyl-toe9oF2e): Each all to Run() currently creates a container which
// gets deleted as the context gets deleted. It would be nicer to
// explicitly clean up the container ourselves at this point (after calling
// testing::StopTiming).
}
void BM_SDCA_LARGE_SPARSE(const int iters, const int num_examples) {
testing::StopTiming();
Graph* init = nullptr;
Graph* train = nullptr;
GetGraphs(num_examples, 65 /* sparse feature groups */,
1e6 /* sparse features per group */, 0 /* dense features */, &init,
&train);
testing::StartTiming();
test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
}
} // namespace
BENCHMARK(BM_SDCA)->Arg(128)->Arg(256)->Arg(512)->Arg(1024);
BENCHMARK(BM_SDCA_LARGE_SPARSE)->Arg(128)->Arg(256)->Arg(512)->Arg(1024);
} // namespace tensorflow

View File

@ -137,11 +137,15 @@ tf_cc_test(
"//tensorflow/contrib/quantization:cc_array_ops",
"//tensorflow/contrib/quantization:cc_math_ops",
"//tensorflow/contrib/quantization:cc_nn_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//third_party/eigen3",
],
)

View File

@ -16,12 +16,15 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_QUANTIZATION_UTILS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_QUANTIZATION_UTILS_H_
#define EIGEN_USE_THREADS
// This is a set of functions that standardizes how quantized values are
// interpreted as float numbers.
// All of the current implementations are for reference and have not been
// optimized. They should be implementable using fixed point representations
// to avoid a dependency on floating-point hardware.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
@ -104,6 +107,74 @@ void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b,
*max_c = c_float_for_one_quant_level * c_highest;
}
// input_array is an eigen Tensor. q2f is a QuantizedToFloatStruct.
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f);
#define DEQUANTIZE_WITH_EIGEN(input_array, q2f) \
(q2f.range_min + \
(((input_array.template cast<float>() - q2f.lowest_quantized())) * \
q2f.range_scale));
// input_array is an eigen Tensor. f2q is a FloatToQuantizedStruct.
// OutputType is the type of output (e.g. quint8).
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = QUANTIZE_WITH_EIGEN(input_tensor, f2q, T);
#define QUANTIZE_WITH_EIGEN(input_array, f2q, OutputType) \
((input_array * f2q.range_scale).round() - \
(f2q.range_min_scaled - f2q.lowest_quantized())) \
.cwiseMax(f2q.lowest_quantized()) \
.cwiseMin(f2q.highest_quantized()) \
.template cast<int32>() \
.template cast<OutputType>()
// For use with DEQUANTIZE_WITH_EIGEN.
template <typename T>
struct QuantizedToFloatStruct {
static constexpr int number_of_bits = sizeof(T) * 8;
static constexpr int64 number_of_steps = static_cast<int64>(1)
<< number_of_bits;
static float lowest_quantized() {
return static_cast<float>(Eigen::NumTraits<T>::lowest());
}
QuantizedToFloatStruct(float range_min, float range_max)
: range_min(range_min),
range_scale((range_max - range_min) / (number_of_steps - 1.0)) {}
const float range_min;
const float range_scale;
};
// For use with QUANTIZE_WITH_EIGEN.
template <typename T>
struct FloatToQuantizedStruct {
static constexpr int number_of_bits = sizeof(T) * 8;
static constexpr int64 number_of_steps = static_cast<int64>(1)
<< number_of_bits;
static constexpr double range_adjust =
(number_of_steps / (number_of_steps - 1.0));
static float lowest_quantized() {
return static_cast<float>(Eigen::NumTraits<T>::lowest());
}
static double lowest_quantized_double() {
return static_cast<double>(Eigen::NumTraits<T>::lowest());
}
static float highest_quantized() {
return static_cast<float>(Eigen::NumTraits<T>::highest());
}
FloatToQuantizedStruct(float range_min, float range_max)
: range_min(range_min),
range_scale((number_of_steps - 1.0) / (range_max - range_min)),
range_min_scaled(round(range_min * range_scale)) {}
const float range_min;
const float range_scale;
const float range_min_scaled;
};
template <class T1, class T2>
inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input,
float min_new, float max_new) {
@ -130,19 +201,22 @@ inline void RequantizeManyInNewRange<qint32, quint8>(
qint32* input, size_t count, float min_input, float max_input,
float min_output, float max_output, quint8* output) {
// Initially we calculate all the constants we need once, before we go into
// the inner loop.
// the inner loop. If this is updated, also update the Eigen version.
const int fp_shift = 16;
const float input_range = max_input - min_input;
const float output_range = max_output - min_output;
const float recip_output_range = (255.0 / output_range);
const float recip_output_range =
output_range == 0.0 ? 0.0 : (255.0 / output_range);
const int64 recip_output_range_fp =
static_cast<int64>(recip_output_range * (1 << fp_shift));
const int64 range_scale_fp =
static_cast<int64>(255.0 * (1 << fp_shift) * input_range / output_range);
const int64 input_offset_fp =
(min_input * recip_output_range_fp) + (range_scale_fp >> 1);
const int64 output_offset_fp = round((min_output * 255.0) / output_range);
const int64 output_offset_fp =
output_range == 0.0 ? 0.0 : round((min_output * 255.0) / output_range);
const int64 rounding_delta = 1 << (fp_shift - 1);
// Inside this loop we just do minimal adds, multiplies, and shifts, in a way
// that could be easily adapted for a SIMD implementation. It should also be
// possible to perform all the calculations in 32-bit rather than 64, but
@ -162,6 +236,77 @@ inline void RequantizeManyInNewRange<qint32, quint8>(
}
}
template <int shift>
struct int64_right_shift_op {
EIGEN_EMPTY_STRUCT_CTOR(int64_right_shift_op)
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const int64 operator()(const int64& a) const {
return a >> shift;
}
};
// See RequantizeManyInNewRange() for a non-eigen reference implementation.
template <class T1, class T2>
inline void RequantizeManyInNewRangeUsingEigen(
const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
float max_input, float min_output, float max_output, Tensor* output) {
auto input_array = input.flat<T1>();
QuantizedToFloatStruct<T1> q2f(min_input, max_input);
auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
FloatToQuantizedStruct<T2> f2q(min_output, max_output);
auto input_requantized = QUANTIZE_WITH_EIGEN(input_float, f2q, T2);
output->flat<T2>().device(device) = input_requantized;
}
#if 0
// See RequantizeManyInNewRange() for a non-eigen reference implementation.
//
// Because converting 32-bit accumulated results down to eight bit is a common
// case, we have a specialized code path to handle it as efficiently as
// possible using only fixed-point math for the inner loop.
//
// See #ifdefed out test in quantization_utils_test.cc
// (RequantizeManyInNewRange32To8BitUsingEigen).
template <>
inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
float max_input, float min_output, float max_output, Tensor* output) {
// Initially we calculate all the constants we need once, before we go into
// the inner loop. If this is updated, also update the non-Eigen version.
const int fp_shift = 16;
const float input_range = max_input - min_input;
const float output_range = max_output - min_output;
const float recip_output_range =
output_range == 0.0 ? 0.0 : (255.0 / output_range);
const int64 recip_output_range_fp =
static_cast<int64>(recip_output_range * (1 << fp_shift));
const int64 range_scale_fp =
static_cast<int64>(255.0 * (1 << fp_shift) * input_range / output_range);
const int64 input_offset_fp =
(min_input * recip_output_range_fp) + (range_scale_fp >> 1);
const int64 output_offset_fp =
output_range == 0.0 ? 0.0 : round((min_output * 255.0) / output_range);
const int64 rounding_delta = 1 << (fp_shift - 1);
// Inside this eigen expression we just do minimal adds, multiplies, and
// shifts. It should be possible to perform all the calculations in 32-bit
// rather than 64, but that's not been implemented yet.
auto input_array = input.flat<qint32>();
auto fp_value = ((input_array.template cast<int64>() * range_scale_fp)
.unaryExpr(int64_right_shift_op<32>())) +
input_offset_fp;
auto round_intermediate = (fp_value + rounding_delta * fp_value.sign())
.unaryExpr(int64_right_shift_op<fp_shift>());
auto input_requantized = (round_intermediate - output_offset_fp)
.cwiseMax(0LL)
.cwiseMin(255LL)
.template cast<int32>()
.template cast<quint8>();
output->flat<quint8>().device(device) = input_requantized;
}
#endif
// REQUIRES: 'result->NumElements() == input.NumElements()'
template <class T>
void FloatTensorToQuantizedInPlace(const Tensor& input, float min, float max,

View File

@ -13,17 +13,145 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include <limits>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/quantization/kernels/quantization_utils.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
TEST(QuantizationUtils, FloatToQuantized) {
class QuantizationUtilsTest : public ::testing::Test {
protected:
// If eigen_device is NULL, then the reference implementation is tested.
void TestRequantizeManyInNewRange32To8Bit(
Eigen::ThreadPoolDevice* eigen_device) {
// These are the float values we're going to test the conversions on.
const size_t values_count = 6;
const float values[values_count] = {0.0f, 0.45f, 1.0f,
-1.0f, 127.0f, 255.0f};
// These are the input and output ranges we'll test.
const size_t ranges_count = 6;
const float ranges[ranges_count][4] = {
{0.0f, 255.0f, 0.0f, 255.0f}, //
{0.0f, 1.0f, 0.0f, 1.0f}, //
{-1.0f, 1.0f, -1.0f, 1.0f}, //
{-1.0f, 1.0f, -255.0f, 255.0f}, //
{3.0f, 3.0f, 0.0f, 255.0f}, // input min == max
{0.0f, 255.0f, 5.0f, 5.0f}, // output min == max
};
for (size_t range_index = 0; range_index < ranges_count; ++range_index) {
const float input_min = ranges[range_index][0];
const float input_max = ranges[range_index][1];
const float output_min = ranges[range_index][2];
const float output_max = ranges[range_index][3];
std::vector<qint32> values_quantized;
std::vector<quint8> expected_values;
for (size_t value_index = 0; value_index < values_count; ++value_index) {
const float value_float = values[value_index];
values_quantized.push_back(
FloatToQuantized<qint32>(value_float, input_min, input_max));
expected_values.push_back(FloatToQuantized<quint8>(
QuantizedToFloat(values_quantized[value_index], input_min,
input_max),
output_min, output_max));
}
Tensor i_tensor =
tensorflow::test::AsTensor(gtl::ArraySlice<qint32>(values_quantized));
Tensor o_tensor(DT_QUINT8, TensorShape{values_count});
auto output_values = o_tensor.flat<quint8>();
if (eigen_device == nullptr) {
auto input_array = i_tensor.flat<qint32>();
RequantizeManyInNewRange(input_array.data(), input_array.size(),
input_min, input_max, output_min, output_max,
output_values.data());
} else {
RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
*eigen_device, i_tensor, input_min, input_max, output_min,
output_max, &o_tensor);
}
for (size_t value_index = 0; value_index < values_count; ++value_index) {
// Here we convert the quantized input value to what we expect
// to get in the output range.
ASSERT_EQ(expected_values[value_index], output_values(value_index))
<< "values_quantized[" << value_index
<< "]=" << values_quantized[value_index] << ", values["
<< value_index << "]=" << values[value_index]
<< ", input_min=" << input_min << ", input_max=" << input_max
<< ", output_min=" << output_min << ", output_max=" << output_max
<< ", value_index=" << value_index;
}
}
}
template <typename InputType, typename OutputType>
void TestRequantizeManyInNewRangeEigenVsNonEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
const size_t ranges_count = 6;
const float ranges[ranges_count][4] = {
{0.0f, 255.0f, 0.0f, 255.0f}, //
{0.0f, 1.0f, 0.0f, 1.0f}, //
{-1.0f, 1.0f, -1.0f, 1.0f}, //
{-1.0f, 1.0f, -255.0f, 255.0f}, //
{3.0f, 3.0f, 0.0f, 255.0f}, // input min == max
{0.0f, 255.0f, 5.0f, 5.0f}, // output min == max
};
// Random values.
for (size_t range_index = 0; range_index < ranges_count; ++range_index) {
const float input_min = ranges[range_index][0];
const float input_max = ranges[range_index][1];
const float output_min = ranges[range_index][2];
const float output_max = ranges[range_index][3];
const int values_count = 10000;
random::PhiloxRandom philox(testing::RandomSeed(), 17);
random::SimplePhilox rnd(&philox);
std::vector<InputType> values_quantized;
for (int i = 0; i < values_count; ++i) {
float v = (rnd.RandFloat() * (input_max - input_min)) + input_min;
values_quantized.push_back(
FloatToQuantized<InputType>(v, input_min, input_max));
}
Tensor i_tensor = tensorflow::test::AsTensor(
gtl::ArraySlice<InputType>(values_quantized));
const auto i_array = i_tensor.flat<InputType>();
Tensor o_tensor_eigen(DataTypeToEnum<OutputType>::v(),
TensorShape{values_count});
auto output_values_eigen = o_tensor_eigen.flat<OutputType>();
Tensor o_tensor_ref(DataTypeToEnum<OutputType>::v(),
TensorShape{values_count});
auto output_values_ref = o_tensor_ref.flat<OutputType>();
RequantizeManyInNewRange(i_array.data(), i_array.size(), input_min,
input_max, output_min, output_max,
output_values_ref.data());
RequantizeManyInNewRangeUsingEigen<InputType, OutputType>(
eigen_device, i_tensor, input_min, input_max, output_min, output_max,
&o_tensor_eigen);
for (int i = 0; i < values_quantized.size(); ++i) {
EXPECT_EQ(output_values_eigen(i), output_values_ref(i)) << i;
}
}
}
};
TEST_F(QuantizationUtilsTest, FloatToQuantized) {
EXPECT_EQ(quint8(0), FloatToQuantized<quint8>(0.0f, 0.0f, 1.0f));
EXPECT_EQ(quint8(0), FloatToQuantized<quint8>(0.0f, 0.0f, 2.0f));
EXPECT_EQ(quint8(128), FloatToQuantized<quint8>(0.5f, 0.0f, 1.0f));
@ -47,7 +175,7 @@ TEST(QuantizationUtils, FloatToQuantized) {
FloatToQuantized<qint32>(128.0f, -128.0f, 128.0f));
}
TEST(QuantizationUtils, QuantizedToFloat) {
TEST_F(QuantizationUtilsTest, QuantizedToFloat) {
EXPECT_LT(fabsf(0.0f - QuantizedToFloat<quint8>(0, 0.0f, 1.0f)), 1 / 255.0f);
EXPECT_LT(fabsf(0.0f - QuantizedToFloat<quint8>(0, 0.0f, 2.0f)), 1 / 255.0f);
EXPECT_LT(fabsf(0.5f - QuantizedToFloat<quint8>(127, 0.0f, 1.0f)),
@ -78,7 +206,7 @@ TEST(QuantizationUtils, QuantizedToFloat) {
1e-5f);
}
TEST(QuantizationUtils, AvoidBias) {
TEST_F(QuantizationUtilsTest, AvoidBias) {
for (int i = 0; i < 256; ++i) {
const float as_float = QuantizedToFloat<quint8>(i, 0.0f, 2.0f);
const int back_to_int = FloatToQuantized<quint8>(as_float, 0.0f, 2.0f);
@ -86,7 +214,7 @@ TEST(QuantizationUtils, AvoidBias) {
}
}
TEST(QuantizationUtils, RequantizeInNewRange) {
TEST_F(QuantizationUtilsTest, RequantizeInNewRange) {
// These are the float values we're going to test the conversions on.
const size_t values_count = 6;
const float values[values_count] = {0.0f, 0.5f, 1.0f, -1.0f, 127.0f, 255.0f};
@ -122,7 +250,7 @@ TEST(QuantizationUtils, RequantizeInNewRange) {
}
}
TEST(QuantizationUtils, RequantizeInNewRangeRealData) {
TEST_F(QuantizationUtilsTest, RequantizeInNewRangeRealData) {
const float value_as_float = -0.290169f;
const float input_min = -0.739539f;
const float input_max = 0.641057f;
@ -138,7 +266,7 @@ TEST(QuantizationUtils, RequantizeInNewRangeRealData) {
EXPECT_LT(std::abs(value_as_qint32 - actual_output), 10);
}
TEST(QuantizationUtils, RequantizeInNewRange32To8Bit) {
TEST_F(QuantizationUtilsTest, RequantizeInNewRange32To8Bit) {
// These are the float values we're going to test the conversions on.
const size_t values_count = 6;
const float values[values_count] = {0.0f, 0.45f, 1.0f, -1.0f, 127.0f, 255.0f};
@ -174,51 +302,29 @@ TEST(QuantizationUtils, RequantizeInNewRange32To8Bit) {
}
}
TEST(QuantizationUtils, RequantizeManyInNewRange32To8Bit) {
// These are the float values we're going to test the conversions on.
const size_t values_count = 6;
const float values[values_count] = {0.0f, 0.45f, 1.0f, -1.0f, 127.0f, 255.0f};
// These are the input and output ranges we'll test.
const size_t ranges_count = 4;
const float ranges[ranges_count][4] = {
{0.0f, 255.0f, 0.0f, 255.0f},
{0.0f, 1.0f, 0.0f, 1.0f},
{-1.0f, 1.0f, -1.0f, 1.0f},
{-1.0f, 1.0f, -255.0f, 255.0f},
};
for (size_t range_index = 0; range_index < ranges_count; ++range_index) {
const float input_min = ranges[range_index][0];
const float input_max = ranges[range_index][1];
const float output_min = ranges[range_index][2];
const float output_max = ranges[range_index][3];
qint32 values_quantized[values_count];
quint8 expected_values[values_count];
for (size_t value_index = 0; value_index < values_count; ++value_index) {
const float value_float = values[value_index];
values_quantized[value_index] =
FloatToQuantized<qint32>(value_float, input_min, input_max);
expected_values[value_index] = FloatToQuantized<quint8>(
QuantizedToFloat(values_quantized[value_index], input_min, input_max),
output_min, output_max);
}
quint8 output_values[values_count];
RequantizeManyInNewRange<qint32, quint8>(values_quantized, values_count,
input_min, input_max, output_min,
output_max, output_values);
for (size_t value_index = 0; value_index < values_count; ++value_index) {
// Here we convert the quantized input value to what we expect
// to get in the output range.
EXPECT_EQ(expected_values[value_index], output_values[value_index])
<< "values_quantized[" << value_index
<< "]=" << values_quantized[value_index] << ", values[" << value_index
<< "]=" << values[value_index] << ", input_min=" << input_min
<< ", input_max=" << input_max << ", output_min=" << output_min
<< ", output_max=" << output_max << ", value_index=" << value_index;
}
}
TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8Bit) {
TestRequantizeManyInNewRange32To8Bit(nullptr /* eigen_device */);
}
TEST(QuantizationUtils, FloatTensorToQuantized) {
#if 0
TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8BitUsingEigen) {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
TestRequantizeManyInNewRange32To8Bit(&eigen_device);
}
#endif
TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8BitEigenVsNonEigen) {
TestRequantizeManyInNewRangeEigenVsNonEigen<qint32, quint8>();
}
TEST_F(QuantizationUtilsTest,
RequantizeManyInNewRange32To8BitSignedEigenVsNonEigen) {
TestRequantizeManyInNewRangeEigenVsNonEigen<qint32, qint8>();
}
TEST_F(QuantizationUtilsTest, FloatTensorToQuantized) {
const int input_width = 3;
const int input_height = 3;
const float input_min = 0.0f;
@ -232,7 +338,7 @@ TEST(QuantizationUtils, FloatTensorToQuantized) {
test::ExpectTensorEqual<quint8>(expected, output);
}
TEST(QuantizationUtils, QuantizedTensorToFloat) {
TEST_F(QuantizationUtilsTest, QuantizedTensorToFloat) {
const int input_width = 3;
const int input_height = 3;
const float input_min = -128.0f;

View File

@ -15,8 +15,11 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS
#include <math.h>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/quantization/kernels/quantization_utils.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -26,6 +29,8 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <class T1, class T2>
class QuantizeDownAndShrinkRangeOp : public OpKernel {
public:
@ -62,10 +67,22 @@ class QuantizeDownAndShrinkRangeOp : public OpKernel {
input_max_float));
const float actual_max_float = QuantizedToFloat(
actual_max_quantized, input_min_float, input_max_float);
#if 0
// This is the reference, non-eigen implementation:
auto output_array = output->flat<T2>();
RequantizeManyInNewRange(input_array.data(), input_array.size(),
input_min_float, input_max_float, actual_min_float,
actual_max_float, output_array.data());
#endif
if (input_array.size() > 0) {
RequantizeManyInNewRangeUsingEigen<T1, T2>(
ctx->eigen_device<CPUDevice>(), input, input_min_float,
input_max_float, actual_min_float, actual_max_float, output);
}
output_min->flat<float>().setConstant(actual_min_float);
output_max->flat<float>().setConstant(actual_max_float);
}

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/quantization/kernels/quantization_utils.h"
#include "tensorflow/core/framework/numeric_op.h"

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

View File

@ -142,9 +142,9 @@ def are_tensors_near(a, b, tolerance):
return True
else:
print("Tensors have {0} different values ({1}%), with mean difference"
" {2} and mean absolute difference {3}").format(
" {2} and mean absolute difference {3}".format(
how_many_different, proportion_different * 100, mean_difference,
mean_abs_difference)
mean_abs_difference))
return False

View File

@ -952,6 +952,7 @@ tf_cc_tests(
":io",
":ops_testutil",
":ops_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",

View File

@ -42,6 +42,10 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
const Tensor& boxes,
const Tensor& box_ind,
int* num_boxes) {
if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
*num_boxes = 0;
return;
}
// The shape of 'boxes' is [num_boxes, 4].
OP_REQUIRES(context, boxes.dims() == 2,
errors::InvalidArgument("boxes must be 2-D",
@ -132,9 +136,13 @@ class CropAndResizeOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
functor::CropAndResize<Device, T>()(context->eigen_device<Device>(),
image_data, boxes_data, box_ind_data,
extrapolation_value_, crops_data);
bool status = functor::CropAndResize<Device, T>()(
context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
extrapolation_value_, crops_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
}
}
private:
@ -145,11 +153,12 @@ class CropAndResizeOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResize<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@ -163,7 +172,11 @@ struct CropAndResize<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@ -217,6 +230,7 @@ struct CropAndResize<CPUDevice, T> {
}
}
}
return true;
}
};
} // namespace functor
@ -235,6 +249,7 @@ class CropAndResizeGradImageOp : public OpKernel {
void Compute(OpKernelContext* context) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
@ -294,9 +309,13 @@ class CropAndResizeGradImageOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
functor::CropAndResizeBackpropImage<Device, T>()(
bool status = functor::CropAndResizeBackpropImage<Device, T>()(
context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
output_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
}
}
};
@ -304,11 +323,12 @@ class CropAndResizeGradImageOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResizeBackpropImage<CPUDevice, T> {
void operator()(const CPUDevice& d,
bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<T, 4>::Tensor grads_image) {
const int batch = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@ -324,7 +344,11 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@ -370,6 +394,7 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
}
}
}
return true;
}
};
} // namespace functor
@ -388,6 +413,7 @@ class CropAndResizeGradBoxesOp : public OpKernel {
void Compute(OpKernelContext* context) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
@ -441,9 +467,13 @@ class CropAndResizeGradBoxesOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
functor::CropAndResizeBackpropBoxes<Device, T>()(
bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
context->eigen_device<Device>(), grads_data, image_data, boxes_data,
box_ind_data, output_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
}
}
};
@ -451,12 +481,13 @@ class CropAndResizeGradBoxesOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResizeBackpropBoxes<CPUDevice, T> {
void operator()(const CPUDevice& d,
bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<float, 2>::Tensor grads_boxes) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@ -472,7 +503,11 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_ratio =
(crop_height > 1)
@ -547,6 +582,7 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
}
}
}
return true;
}
};
} // namespace functor
@ -563,37 +599,25 @@ inline void CheckValidBoxInd<CPUDevice>(
}
}
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("image_size"), \
CropAndResizeGradImageOp<CPUDevice, T>);
TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("image_size"), \
CropAndResizeGradImageOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
#undef REGISTER_KERNEL
@ -613,6 +637,10 @@ template <>
inline void CheckValidBoxInd<GPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
int batch) {
const int num_boxes = box_ind.dimension(0);
if (num_boxes == 0) {
return;
}
Tensor isvalid_tensor;
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<bool>::value,
@ -657,7 +685,7 @@ inline void CheckValidBoxInd<GPUDevice>(
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<GPUDevice, T>);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
#undef REGISTER_KERNEL

View File

@ -26,7 +26,7 @@ namespace functor {
template <typename Device, typename T>
struct CropAndResize {
// We assume that the tensor sizes are correct.
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor image,
bool operator()(const Device& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
@ -36,7 +36,7 @@ struct CropAndResize {
template <typename Device, typename T>
struct CropAndResizeBackpropImage {
// We assume that the tensor sizes are correct.
void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<T, 4>::Tensor grads_image);
@ -45,7 +45,7 @@ struct CropAndResizeBackpropImage {
template <typename Device, typename T>
struct CropAndResizeBackpropBoxes {
// We assume that the tensor sizes are correct.
void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,

View File

@ -33,27 +33,30 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename T>
__global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes,
int image_height, int image_width,
int crop_height, int crop_width, int depth,
float extrapolation_value,
float* crops_ptr) {
__global__ void CropAndResizeKernel(
const int32 nthreads, const T* image_ptr, const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
int image_width, int crop_height, int crop_width, int depth,
float extrapolation_value, float* crops_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
const int d = out_idx % depth;
const int out_idx2 = out_idx / depth;
const int x = out_idx2 % crop_width;
const int out_idx3 = out_idx2 / crop_width;
const int y = out_idx3 % crop_height;
const int b = out_idx3 / crop_height;
int idx = out_idx;
const int d = idx % depth;
idx /= depth;
const int x = idx % crop_width;
idx /= crop_width;
const int y = idx % crop_height;
const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
const int32 b_in = box_ind_ptr[b];
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@ -66,7 +69,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
crops_ptr[out_idx] = extrapolation_value;
return;
continue;
}
const float in_x = (crop_width > 1)
@ -74,7 +77,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
crops_ptr[out_idx] = extrapolation_value;
return;
continue;
}
const int top_y_index = floorf(in_y);
@ -114,22 +117,28 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
template <typename T>
__global__ void CropAndResizeBackpropImageKernel(
const int32 nthreads, const float* grads_ptr, const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes, int image_height, int image_width,
int crop_height, int crop_width, int depth, T* grads_image_ptr) {
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
int image_width, int crop_height, int crop_width, int depth,
T* grads_image_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
const int d = out_idx % depth;
const int out_idx2 = out_idx / depth;
const int x = out_idx2 % crop_width;
const int out_idx3 = out_idx2 / crop_width;
const int y = out_idx3 % crop_height;
const int b = out_idx3 / crop_height;
int idx = out_idx;
const int d = idx % depth;
idx /= depth;
const int x = idx % crop_width;
idx /= crop_width;
const int y = idx % crop_height;
const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
const int32 b_in = box_ind_ptr[b];
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@ -141,14 +150,14 @@ __global__ void CropAndResizeBackpropImageKernel(
? y1 * (image_height - 1) + y * height_scale
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
return;
continue;
}
const float in_x = (crop_width > 1)
? x1 * (image_width - 1) + x * width_scale
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
return;
continue;
}
const int top_y_index = floorf(in_y);
@ -192,23 +201,28 @@ __global__ void CropAndResizeBackpropImageKernel(
template <typename T>
__global__ void CropAndResizeBackpropBoxesKernel(
const int32 nthreads, const float* grads_ptr, const T* image_ptr,
const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes,
const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, int batch,
int image_height, int image_width, int crop_height, int crop_width,
int depth, float* grads_boxes_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
const int d = out_idx % depth;
const int out_idx2 = out_idx / depth;
const int x = out_idx2 % crop_width;
const int out_idx3 = out_idx2 / crop_width;
const int y = out_idx3 % crop_height;
const int b = out_idx3 / crop_height;
int idx = out_idx;
const int d = idx % depth;
idx /= depth;
const int x = idx % crop_width;
idx /= crop_width;
const int y = idx % crop_height;
const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
const int32 b_in = box_ind_ptr[b];
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_ratio =
(crop_height > 1)
@ -226,14 +240,14 @@ __global__ void CropAndResizeBackpropBoxesKernel(
? y1 * (image_height - 1) + y * height_scale
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
return;
continue;
}
const float in_x = (crop_width > 1)
? x1 * (image_width - 1) + x * width_scale
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
return;
continue;
}
const int top_y_index = floorf(in_y);
@ -306,11 +320,12 @@ namespace functor {
template <typename T>
struct CropAndResize<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
bool operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@ -320,19 +335,22 @@ struct CropAndResize<GPUDevice, T> {
const int depth = crops.dimension(3);
const int total_count = num_boxes * crop_height * crop_width * depth;
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
d.stream()>>>(
config.virtual_thread_count, image.data(), boxes.data(), box_ind.data(),
num_boxes, image_height, image_width, crop_height, crop_width, depth,
extrapolation_value, crops.data());
if (total_count > 0) {
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
d.stream()>>>(
config.virtual_thread_count, image.data(), boxes.data(),
box_ind.data(), num_boxes, batch, image_height, image_width,
crop_height, crop_width, depth, extrapolation_value, crops.data());
}
return d.ok();
}
};
template <typename T>
struct CropAndResizeBackpropImage<GPUDevice, T> {
void operator()(const GPUDevice& d,
bool operator()(const GPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
@ -351,29 +369,35 @@ struct CropAndResizeBackpropImage<GPUDevice, T> {
// Initialize grads_image with all zeros.
total_count = batch * image_height * image_width * depth;
config = GetCudaLaunchConfig(total_count, d);
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
total_count, grads_image.data());
if (total_count > 0) {
config = GetCudaLaunchConfig(total_count, d);
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads_image.data());
}
// Accumulate.
total_count = num_boxes * crop_height * crop_width * depth;
config = GetCudaLaunchConfig(total_count, d);
CropAndResizeBackpropImageKernel<<<
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), boxes.data(), box_ind.data(),
num_boxes, image_height, image_width, crop_height, crop_width, depth,
grads_image.data());
if (total_count > 0) {
config = GetCudaLaunchConfig(total_count, d);
CropAndResizeBackpropImageKernel<<<
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), boxes.data(),
box_ind.data(), num_boxes, batch, image_height, image_width,
crop_height, crop_width, depth, grads_image.data());
}
return d.ok();
}
};
template <typename T>
struct CropAndResizeBackpropBoxes<GPUDevice, T> {
void operator()(const GPUDevice& d,
bool operator()(const GPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<float, 2>::Tensor grads_boxes) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@ -387,18 +411,23 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
// Initialize grads_boxes with all zeros.
total_count = num_boxes * 4;
config = GetCudaLaunchConfig(total_count, d);
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
total_count, grads_boxes.data());
if (total_count > 0) {
config = GetCudaLaunchConfig(total_count, d);
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads_boxes.data());
}
// Accumulate.
total_count = num_boxes * crop_height * crop_width * depth;
config = GetCudaLaunchConfig(total_count, d);
CropAndResizeBackpropBoxesKernel<<<
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), image.data(), boxes.data(),
box_ind.data(), num_boxes, image_height, image_width, crop_height,
crop_width, depth, grads_boxes.data());
if (total_count > 0) {
config = GetCudaLaunchConfig(total_count, d);
CropAndResizeBackpropBoxesKernel<<<
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), image.data(), boxes.data(),
box_ind.data(), num_boxes, batch, image_height, image_width,
crop_height, crop_width, depth, grads_boxes.data());
}
return d.ok();
}
};
@ -407,7 +436,7 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
template struct CropAndResizeBackpropImage<GPUDevice, T>; \
template struct CropAndResizeBackpropBoxes<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS);
TF_CALL_float(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS

View File

@ -189,6 +189,24 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
MakeOp(0);
// Input:
// 1, 2
// 3, 4
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({0, 4}), {});
AddInputFromArray<int32>(TensorShape({0}), {});
AddInputFromArray<int32>(TensorShape({2}), {3, 3});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({0, 3, 3, 1}));
// clang-format off
test::FillValues<float>(&expected, {});
// clang-format on
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
MakeOp(0);
AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
@ -201,6 +219,19 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
<< s;
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
MakeOp(0);
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
AddInputFromArray<int32>(TensorShape({2}), {4, 4});
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
<< s;
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
MakeOp(0);
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});

View File

@ -16,6 +16,9 @@ limitations under the License.
#include <functional>
#include <memory>
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/io_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -24,12 +27,15 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
namespace tensorflow {
@ -366,7 +372,6 @@ TEST_F(SaveOpTest, Simple) {
EXPECT_EQ(200 + i, data[i].imag());
}
}
{
// The 2-d half tensor
TensorShape shape;
@ -652,5 +657,40 @@ TEST_F(SaveOpSlices2Test, TwoSlices) {
}
}
// Benchmark-related code below.
static void BM_LargeTensorWrite(int iters, int num_elements) {
testing::StopTiming();
// 4 * num_elements bytes total , since sizeof(float) == 4.
Tensor tensor(DT_FLOAT, TensorShape({num_elements}));
tensor.flat<float>().setZero();
// Builds the graph.
const string temp_filename =
io::JoinPath(testing::TmpDir(), "benchmark_checkpoint");
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* filename = ops::Const(test::AsScalar<string>(temp_filename), b.opts());
Node* tensor_names =
ops::Const(test::AsTensor<string>({"my_tensor"}), b.opts());
Node* tensors = ops::Const(tensor, b.opts());
ops::Save(filename, tensor_names, {tensors}, b.opts());
// Disables optimizations.
SessionOptions session_options;
session_options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_opt_level(tensorflow::OptimizerOptions_Level_L0);
Graph* g = new Graph(OpRegistry::Global());
TF_CHECK_OK(b.ToGraph(g));
VLOG(1) << "Save op's output path: " << temp_filename;
VLOG(1) << "# nodes in Graph: " << g->num_nodes();
testing::StartTiming();
test::Benchmark("cpu", g, &session_options).Run(iters);
}
BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);
} // namespace
} // namespace tensorflow

View File

@ -47,6 +47,74 @@ Add all input tensors element wise.
inputs: Must all be the same size and shape.
)doc");
namespace {
// Shape inference function for binary operators that broadcast their inputs.
Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
const Shape* shape_x = c->input(0);
const Shape* shape_y = c->input(1);
if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
c->set_output(0, c->CreateUnknownShape());
return Status::OK();
}
const int32 rank_x = c->Rank(shape_x);
const int32 rank_y = c->Rank(shape_y);
const int32 rank_out = std::max(rank_x, rank_y);
// To compute the broadcast dimensions, we zip together shape_x and shape_y
// and
// pad with 1 to make them the same length.
std::vector<const Dimension*> dims;
const Dimension* dim_one = rank_x == rank_y ? nullptr : c->CreateDim(1);
for (int i = 0; i < rank_out; ++i) {
const auto* dim_x = i < (rank_out - rank_x)
? dim_one
: c->Dim(shape_x, i - (rank_out - rank_x));
const auto* dim_y = i < (rank_out - rank_y)
? dim_one
: c->Dim(shape_y, i - (rank_out - rank_y));
if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
// One or both dimensions is unknown.
//
// - If either dimension is greater than 1, we assume that the program is
// correct, and the other dimension will be broadcast to match it.
// TODO(cwhipkey): For shape inference, if we eliminate the shape checks
// in C++ op code, we must still assert that the unknown dim is either 1
// or the same as the known dim.
// - If either dimension is 1, the other dimension is the output.
if (c->Value(dim_x) > 1) {
dims.push_back(dim_x);
} else if (c->Value(dim_y) > 1) {
dims.push_back(dim_y);
} else if (c->Value(dim_x) == 1) {
dims.push_back(dim_y);
} else if (c->Value(dim_y) == 1) {
dims.push_back(dim_x);
} else {
dims.push_back(c->CreateUnknownDim());
}
} else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
if (c->Value(dim_x) == 1 && dim_y != dim_one) {
// We will broadcast dim_x to dim_y.
dims.push_back(dim_y);
} else {
DCHECK_EQ(c->Value(dim_y), 1);
// We will broadcast dim_y to dim_x.
dims.push_back(dim_x);
}
} else {
const Dimension* dim;
TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
dims.push_back(dim);
}
}
c->set_output(0, c->CreateShape(dims));
return Status::OK();
}
} // namespace
// --------------------------------------------------------------------------
REGISTER_OP("BatchMatMul")
@ -373,6 +441,7 @@ REGISTER_OP("Add")
.Attr(
"T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
"complex128, string}")
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Returns x + y element-wise.
@ -381,6 +450,7 @@ Returns x + y element-wise.
REGISTER_OP("Sub")
.BINARY_FEWER()
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Returns x - y element-wise.
)doc");
@ -388,12 +458,14 @@ Returns x - y element-wise.
REGISTER_OP("Mul")
.BINARY_MORE()
.SetIsCommutative()
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Returns x * y element-wise.
)doc");
REGISTER_OP("Div")
.BINARY_MORE()
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Returns x / y element-wise.
)doc");
@ -401,6 +473,7 @@ Returns x / y element-wise.
REGISTER_OP("SquaredDifference")
.BINARY_FEWER()
.SetIsCommutative()
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Returns (x - y)(x - y) element-wise.
)doc");
@ -414,6 +487,7 @@ REGISTER_OP("Maximum")
.Output("z: T")
.Attr("T: {half, float, double, int32, int64}")
.SetIsCommutative()
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise, broadcasts.
)doc");
@ -424,6 +498,7 @@ REGISTER_OP("Minimum")
.Output("z: T")
.Attr("T: {half, float, double, int32, int64}")
.SetIsCommutative()
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Returns the min of x and y (i.e. x < y ? x : y) element-wise, broadcasts.
)doc");
@ -433,6 +508,7 @@ REGISTER_OP("Mod")
.Input("y: T")
.Output("z: T")
.Attr("T: {int32, int64, float, double}")
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Returns element-wise remainder of division.
)doc");
@ -442,6 +518,7 @@ REGISTER_OP("Pow")
.Input("y: T")
.Output("z: T")
.Attr("T: {half, float, double, int32, int64, complex64, complex128}")
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Computes the power of one value to another.
@ -460,6 +537,7 @@ REGISTER_OP("Igammac")
.Input("x: T")
.Output("z: T")
.Attr("T: {float, double}")
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Compute the upper regularized incomplete Gamma function `Q(a, x)`.
@ -483,6 +561,7 @@ REGISTER_OP("Igamma")
.Input("x: T")
.Output("z: T")
.Attr("T: {float, double}")
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Compute the lower regularized incomplete Gamma function `Q(a, x)`.
@ -506,6 +585,7 @@ REGISTER_OP("Zeta")
.Input("q: T")
.Output("z: T")
.Attr("T: {float, double}")
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
@ -521,6 +601,7 @@ REGISTER_OP("Polygamma")
.Input("x: T")
.Output("z: T")
.Attr("T: {float, double}")
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Compute the polygamma function \\(\psi^{(n)}(x)\\).
@ -536,8 +617,12 @@ where \\(\psi(x)\\) is the digamma function.
// Declares cwise binary comparison operations signature: 't, 't -> bool,
// where 't has a natural total order.
#define COMPARISON() \
Input("x: T").Input("y: T").Output("z: bool").Attr("T: realnumbertype")
#define COMPARISON() \
Input("x: T") \
.Input("y: T") \
.Output("z: bool") \
.Attr("T: realnumbertype") \
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
REGISTER_OP("Less")
.COMPARISON()
@ -567,10 +652,16 @@ Returns the truth value of (x >= y) element-wise.
// --------------------------------------------------------------------------
#define EQUALITY_COMPARISON() \
Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr( \
"T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " \
"quint8, qint8, qint32, string, bool, complex128}")
#define EQUALITY_COMPARISON() \
Input("x: T") \
.Input("y: T") \
.Output("z: bool") \
.SetIsCommutative() \
.Attr( \
"T: {half, float, double, uint8, int8, int16, int32, int64, " \
"complex64, " \
"quint8, qint8, qint32, string, bool, complex128}") \
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
REGISTER_OP("Equal")
.EQUALITY_COMPARISON()
@ -596,8 +687,12 @@ REGISTER_OP("LogicalNot")
Returns the truth value of NOT x element-wise.
)doc");
#define BINARY_LOGICAL() \
Input("x: bool").Input("y: bool").Output("z: bool").SetIsCommutative()
#define BINARY_LOGICAL() \
Input("x: bool") \
.Input("y: bool") \
.Output("z: bool") \
.SetIsCommutative() \
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
REGISTER_OP("LogicalAnd")
.BINARY_LOGICAL()
@ -1271,6 +1366,7 @@ REGISTER_OP("Complex")
.Output("out: Tout")
.Attr("T: {float, double} = DT_FLOAT")
.Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
.Doc(R"doc(
Converts two real numbers to a complex number.

View File

@ -121,4 +121,52 @@ TEST(MathOpsTest, Segment) {
}
}
TEST(MathOpsTest, BroadcastBinaryOps) {
for (const auto* op : {"Add", "Complex",
"Div", "Equal",
"Greater", "GreaterEqual",
"Igamma", "Igammac",
"Zeta", "Polygamma",
"Less", "LessEqual",
"LogicalAnd", "LogicalOr",
"Maximum", "Minimum",
"Mod", "Mul",
"NotEqual", "Pow",
"Sub", "SquaredDifference"}) {
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
INFER_OK(op, "?;[1,2]", "?");
INFER_OK(op, "[?];[1]", "[d0_0]");
INFER_OK(op, "[1];[?]", "[d1_0]");
INFER_OK(op, "[?];[2]", "[d1_0]");
INFER_OK(op, "[2];[?]", "[d0_0]");
INFER_OK(op, "[?];[?]", "[?]");
INFER_OK(op, "[];[?]", "[d1_0]");
INFER_OK(op, "[?];[]", "[d0_0]");
INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
INFER_OK(op, "[];[1]", "[d1_0]");
INFER_OK(op, "[1];[]", "[d0_0]");
INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
INFER_OK(op, "[];[2]", "[d1_0]");
INFER_OK(op, "[1];[2]", "[d1_0]");
INFER_OK(op, "[2];[1]", "[d0_0]");
INFER_OK(op, "[2];[]", "[d0_0]");
INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
INFER_OK(op, "[];[0]", "[d1_0]");
INFER_OK(op, "[1];[0]", "[d1_0]");
INFER_OK(op, "[0];[1]", "[d0_0]");
INFER_OK(op, "[0];[]", "[d0_0]");
// Multiple dimension cases (same test cases, switching x and y).
INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
"[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
"[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
}
}
} // end namespace tensorflow

View File

@ -15,7 +15,8 @@ Train and evaluate TensorFlow models.
Abstract BaseEstimator class to train and evaluate TensorFlow models.
Concrete implementation of this class should provide following functions:
Concrete implementation of this class should provide the following functions:
* _get_train_ops
* _get_eval_ops
* _get_predict_ops
@ -65,17 +66,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -108,12 +112,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -126,8 +133,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -141,7 +146,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -296,20 +302,24 @@ Constructs an Estimator instance.
* <b>`model_fn`</b>: Model function, takes features and targets tensors or dicts of
tensors and returns predictions and loss tensors.
Supports next three signatures for the function:
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) ->
(predictions, loss, train_op)`
* <b>`Where`</b>:
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
`dict` of `Tensor`s (for multi-head model).
* `mode` represents if this training, evaluation or
prediction. See `ModeKeys` for example keys.
* `params` is a `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tunning.
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) -> (predictions, loss, train_op)`
Where
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
`dict` of `Tensor`s (for multi-head model).
* `mode` represents if this training, evaluation or
prediction. See `ModeKeys` for example keys.
* `params` is a `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tunning.
* <b>`model_dir`</b>: Directory to save model parameters, graph and etc.
* <b>`config`</b>: Configuration object.
* <b>`params`</b>: `dict` of hyper parameters that will be passed into `model_fn`.
@ -352,17 +362,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -395,12 +408,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -413,8 +429,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -428,7 +442,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -642,17 +657,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -683,7 +701,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -889,15 +908,16 @@ estimator.predict(x=x)
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
- - -
#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns=None, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=True, config=None)` {#DNNClassifier.__init__}
@ -990,17 +1010,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -1033,12 +1056,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -1051,8 +1077,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -1066,7 +1090,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -1282,15 +1307,16 @@ estimator.predict(x=x)
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
- - -
#### `tf.contrib.learn.DNNRegressor.__init__(hidden_units, feature_columns=None, model_dir=None, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=True, config=None)` {#DNNRegressor.__init__}
@ -1381,17 +1407,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -1424,12 +1453,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -1442,8 +1474,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -1457,7 +1487,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -1673,17 +1704,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -1714,7 +1748,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -1938,17 +1973,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -1979,7 +2017,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -2164,10 +2203,15 @@ Initializes a TensorFlowEstimator instance.
Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
* <b>`clip_gradients`</b>: Clip norm of the gradients to this value to stop
gradient explosion.
* <b>`class_weight`</b>: None or list of n_classes floats. Weight associated with
@ -2178,9 +2222,10 @@ Initializes a TensorFlowEstimator instance.
* <b>`config`</b>: RunConfig object that controls the configurations of the
session, e.g. num_cores, gpu_memory_fraction, etc.
* <b>`verbose`</b>: Controls the verbosity, possible values:
0: the algorithm and debug information is muted.
1: trainer prints the progress.
2: log device placement is printed.
* 0: the algorithm and debug information is muted.
* 1: trainer prints the progress.
* 2: log device placement is printed.
- - -
@ -2234,7 +2279,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -2485,15 +2531,16 @@ estimator.predict(x=x)
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
- - -
#### `tf.contrib.learn.LinearClassifier.__init__(feature_columns=None, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, gradient_clip_norm=None, enable_centered_bias=True, config=None)` {#LinearClassifier.__init__}
@ -2579,17 +2626,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -2622,12 +2672,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -2640,8 +2693,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -2655,7 +2706,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -2860,15 +2912,16 @@ estimator.predict(x=x)
Input of `fit` and `evaluate` should have following features,
otherwise there will be a KeyError:
if `weight_column_name` is not `None`:
key=weight_column_name, value=a `Tensor`
for column in `feature_columns`:
- if isinstance(column, `SparseColumn`):
key=column.name, value=a `SparseTensor`
- if isinstance(column, `RealValuedColumn`):
key=column.name, value=a `Tensor`
- if `feauture_columns` is `None`:
input must contains only real valued `Tensor`.
* if `weight_column_name` is not `None`:
key=weight_column_name, value=a `Tensor`
* for column in `feature_columns`:
- if isinstance(column, `SparseColumn`):
key=column.name, value=a `SparseTensor`
- if isinstance(column, `RealValuedColumn`):
key=column.name, value=a `Tensor`
- if `feature_columns` is `None`:
input must contains only real valued `Tensor`.
- - -
#### `tf.contrib.learn.LinearRegressor.__init__(feature_columns=None, model_dir=None, weight_column_name=None, optimizer=None, gradient_clip_norm=None, enable_centered_bias=True, target_dimension=1, config=None)` {#LinearRegressor.__init__}
@ -2953,17 +3006,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -2996,12 +3052,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -3014,8 +3073,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -3029,7 +3086,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -3245,17 +3303,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -3286,7 +3347,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -3510,17 +3572,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -3551,7 +3616,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -3750,10 +3816,15 @@ Initializes a TensorFlowRNNClassifier instance.
used. Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
* <b>`class_weight`</b>: None or list of n_classes floats. Weight associated with
classes for loss computation. If not given, all classes are
supposed to have weight one.
@ -3821,7 +3892,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -4065,18 +4137,24 @@ Initializes a TensorFlowRNNRegressor instance.
used. Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
* <b>`continue_training`</b>: when continue_training is True, once initialized
model will be continuely trained on every call of fit.
* <b>`config`</b>: RunConfig object that controls the configurations of the
session, e.g. num_cores, gpu_memory_fraction, etc.
* <b>`verbose`</b>: Controls the verbosity, possible values:
0: the algorithm and debug information is muted.
1: trainer prints the progress.
2: log device placement is printed.
* 0: the algorithm and debug information is muted.
* 1: trainer prints the progress.
* 2: log device placement is printed.
- - -
@ -4137,7 +4215,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:
@ -4407,17 +4486,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -4448,7 +4530,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -29,15 +29,16 @@ estimator.predict(x=x)
Input of `fit` and `evaluate` should have following features,
otherwise there will be a KeyError:
if `weight_column_name` is not `None`:
key=weight_column_name, value=a `Tensor`
for column in `feature_columns`:
- if isinstance(column, `SparseColumn`):
key=column.name, value=a `SparseTensor`
- if isinstance(column, `RealValuedColumn`):
key=column.name, value=a `Tensor`
- if `feauture_columns` is `None`:
input must contains only real valued `Tensor`.
* if `weight_column_name` is not `None`:
key=weight_column_name, value=a `Tensor`
* for column in `feature_columns`:
- if isinstance(column, `SparseColumn`):
key=column.name, value=a `SparseTensor`
- if isinstance(column, `RealValuedColumn`):
key=column.name, value=a `Tensor`
- if `feature_columns` is `None`:
input must contains only real valued `Tensor`.
- - -
#### `tf.contrib.learn.LinearRegressor.__init__(feature_columns=None, model_dir=None, weight_column_name=None, optimizer=None, gradient_clip_norm=None, enable_centered_bias=True, target_dimension=1, config=None)` {#LinearRegressor.__init__}
@ -122,17 +123,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -165,12 +169,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -183,8 +190,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -198,7 +203,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -46,15 +46,16 @@ estimator.predict(x=x)
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
- - -
#### `tf.contrib.learn.LinearClassifier.__init__(feature_columns=None, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, gradient_clip_norm=None, enable_centered_bias=True, config=None)` {#LinearClassifier.__init__}
@ -140,17 +141,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -183,12 +187,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -201,8 +208,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -216,7 +221,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -1,6 +1,7 @@
Abstract BaseEstimator class to train and evaluate TensorFlow models.
Concrete implementation of this class should provide following functions:
Concrete implementation of this class should provide the following functions:
* _get_train_ops
* _get_eval_ops
* _get_predict_ops
@ -50,17 +51,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -93,12 +97,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -111,8 +118,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -126,7 +131,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -58,17 +58,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -99,7 +102,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -11,20 +11,24 @@ Constructs an Estimator instance.
* <b>`model_fn`</b>: Model function, takes features and targets tensors or dicts of
tensors and returns predictions and loss tensors.
Supports next three signatures for the function:
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) ->
(predictions, loss, train_op)`
* <b>`Where`</b>:
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
`dict` of `Tensor`s (for multi-head model).
* `mode` represents if this training, evaluation or
prediction. See `ModeKeys` for example keys.
* `params` is a `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tunning.
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) -> (predictions, loss, train_op)`
Where
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
`dict` of `Tensor`s (for multi-head model).
* `mode` represents if this training, evaluation or
prediction. See `ModeKeys` for example keys.
* `params` is a `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tunning.
* <b>`model_dir`</b>: Directory to save model parameters, graph and etc.
* <b>`config`</b>: Configuration object.
* <b>`params`</b>: `dict` of hyper parameters that will be passed into `model_fn`.
@ -67,17 +71,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -110,12 +117,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -128,8 +138,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -143,7 +151,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -40,15 +40,16 @@ estimator.predict(x=x)
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
- - -
#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns=None, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=True, config=None)` {#DNNClassifier.__init__}
@ -141,17 +142,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -184,12 +188,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -202,8 +209,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -217,7 +222,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -19,10 +19,15 @@ Initializes a TensorFlowEstimator instance.
Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
* <b>`clip_gradients`</b>: Clip norm of the gradients to this value to stop
gradient explosion.
* <b>`class_weight`</b>: None or list of n_classes floats. Weight associated with
@ -33,9 +38,10 @@ Initializes a TensorFlowEstimator instance.
* <b>`config`</b>: RunConfig object that controls the configurations of the
session, e.g. num_cores, gpu_memory_fraction, etc.
* <b>`verbose`</b>: Controls the verbosity, possible values:
0: the algorithm and debug information is muted.
1: trainer prints the progress.
2: log device placement is printed.
* 0: the algorithm and debug information is muted.
* 1: trainer prints the progress.
* 2: log device placement is printed.
- - -
@ -89,7 +95,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -58,17 +58,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -99,7 +102,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -32,18 +32,24 @@ Initializes a TensorFlowRNNRegressor instance.
used. Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
* <b>`continue_training`</b>: when continue_training is True, once initialized
model will be continuely trained on every call of fit.
* <b>`config`</b>: RunConfig object that controls the configurations of the
session, e.g. num_cores, gpu_memory_fraction, etc.
* <b>`verbose`</b>: Controls the verbosity, possible values:
0: the algorithm and debug information is muted.
1: trainer prints the progress.
2: log device placement is printed.
* 0: the algorithm and debug information is muted.
* 1: trainer prints the progress.
* 2: log device placement is printed.
- - -
@ -104,7 +110,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -58,17 +58,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -99,7 +102,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -58,17 +58,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -99,7 +102,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -58,17 +58,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -99,7 +102,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -40,15 +40,16 @@ estimator.predict(x=x)
Input of `fit` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feauture_columns` is `None`, then `input` must contains only real
valued `Tensor`.
* if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
- if `feature_columns` is `None`, then `input` must contains only real
valued `Tensor`.
- - -
#### `tf.contrib.learn.DNNRegressor.__init__(hidden_units, feature_columns=None, model_dir=None, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=True, config=None)` {#DNNRegressor.__init__}
@ -139,17 +140,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -182,12 +186,15 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`input_fn`</b>: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
* <b>`steps`</b>: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
* <b>`batch_size`</b>: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
* <b>`monitors`</b>: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
* <b>`max_steps`</b>: Number of total steps for which to train model. If `None`,
train forever. Two calls to `fit(steps=100)` means 200 training
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
@ -200,8 +207,6 @@ Trains a model given training data `x` predictions and `y` targets.
* <b>`ValueError`</b>: If `x` or `y` are not `None` while `input_fn` is not `None`.
* <b>`ValueError`</b>: If at least one of `x` and `y` is provided, and `input_fn` is
provided.
* <b>`ValueError`</b>: If both `steps` and `max_steps` are not `None`.
@ -215,7 +220,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -58,17 +58,20 @@ for which this evaluation was performed.
provided.
* <b>`steps`</b>: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
* <b>`metrics`</b>: Dict of metric ops to run. If None, the default metric functions
are used; if {}, no metrics are used. If model has one output (i.e.,
returning single predction), keys are `str`, e.g. `'accuracy'` - just a
name of the metric that will show up in the logs / summaries.
Otherwise, keys are tuple of two `str`, e.g. `('accuracy', 'classes')`
- name of the metric and name of `Tensor` in the predictions to run
this metric on. Metric ops should support streaming, e.g., returning
* <b>`metrics`</b>: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluation on
different data sets, such as evaluate on training data vs test data.
* <b>`name`</b>: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
##### Returns:
@ -99,7 +102,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -33,10 +33,15 @@ Initializes a TensorFlowRNNClassifier instance.
used. Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor.
e.g. exponential decay function:
````python
def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
````
* <b>`class_weight`</b>: None or list of n_classes floats. Weight associated with
classes for loss computation. If not given, all classes are
supposed to have weight one.
@ -104,7 +109,8 @@ Get parameters for this estimator.
* <b>`deep`</b>: boolean, optional
If True, will return the parameters for this estimator and
If `True`, will return the parameters for this estimator and
contained subobjects that are estimators.
##### Returns:

View File

@ -73,11 +73,11 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
Raises:
ValueError: If `params` is empty.
"""
if params is None or params == []: # pylint: disable=g-explicit-bool-comparison
raise ValueError("Need at least one param")
if not isinstance(params, list):
params = [params]
with ops.op_scope(params + [ids], name, "embedding_lookup") as name:
if not params:
raise ValueError("Need at least one param")
np = len(params) # Number of partitions
params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
if np == 1:

View File

@ -253,54 +253,53 @@ class CropAndResizeOpTest(tf.test.TestCase):
radius = 2 * delta
low, high = -0.5, 1.5 # Also covers the case of extrapolation.
for image_height in range(1, 5):
for image_width in range(1, 3):
for crop_height in range(1, 3):
for crop_width in range(2, 4):
for depth in range(1, 3):
for num_boxes in range(1, 3):
image_height = 4
for image_width in range(1, 3):
for crop_height in range(1, 3):
for crop_width in range(2, 4):
for depth in range(1, 3):
for num_boxes in range(1, 3):
batch = num_boxes
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
boxes_shape = [num_boxes, 4]
batch = num_boxes
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
boxes_shape = [num_boxes, 4]
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
boxes = []
for _ in range(num_boxes):
# pylint: disable=unbalanced-tuple-unpacking
y1, y2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_height), radius, 2)
x1, x2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_width), radius, 2)
# pylint: enable=unbalanced-tuple-unpacking
boxes.append([y1, x1, y2, x2])
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
boxes = []
for _ in range(num_boxes):
# pylint: disable=unbalanced-tuple-unpacking
y1, y2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_height), radius, 2)
x1, x2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_width), radius, 2)
# pylint: enable=unbalanced-tuple-unpacking
boxes.append([y1, x1, y2, x2])
boxes = np.array(boxes, dtype=np.float32)
box_ind = np.arange(batch, dtype=np.int32)
boxes = np.array(boxes, dtype=np.float32)
box_ind = np.arange(batch, dtype=np.int32)
for use_gpu in [False, True]:
with self.test_session(use_gpu=use_gpu):
image_tensor = tf.constant(image, shape=image_shape)
boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4])
box_ind_tensor = tf.constant(box_ind, shape=[num_boxes])
crops = tf.image.crop_and_resize(
image_tensor,
boxes_tensor,
box_ind_tensor,
tf.constant(crop_size, shape=[2]))
for use_gpu in [False, True]:
with self.test_session(use_gpu=use_gpu):
image_tensor = tf.constant(image, shape=image_shape)
boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4])
box_ind_tensor = tf.constant(box_ind, shape=[num_boxes])
crops = tf.image.crop_and_resize(
image_tensor,
boxes_tensor,
box_ind_tensor,
tf.constant(crop_size, shape=[2]))
err = tf.test.compute_gradient_error(
[image_tensor, boxes_tensor],
[image_shape, boxes_shape],
crops,
crops_shape,
delta=delta,
x_init_value=[image, boxes])
err = tf.test.compute_gradient_error(
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
crops,
crops_shape,
delta=delta,
x_init_value=[image, boxes])
self.assertLess(err, 2e-3)
self.assertLess(err, 2e-3)
if __name__ == "__main__":

View File

@ -339,7 +339,10 @@ class ExponentialMovingAverage(object):
by the `ExponentialMovingAverage class` to hold the moving average of
`var`.
"""
return var.op.name + "/" + self._name
if var in self._averages:
return self._averages[var].op.name
return ops.get_default_graph().unique_name(
var.op.name + "/" + self._name, mark_as_used=False)
def variables_to_restore(self, moving_avg_variables=None):
"""Returns a map of names to `Variables` to restore.

View File

@ -208,6 +208,37 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
def testAverageVariablesNamesRespectScope(self):
# See discussion on #2740.
with self.test_session():
with tf.variable_scope("scope1"):
v0 = tf.Variable(10.0, name="v0")
v1 = tf.Variable(30.0, name="v1")
# Add a non-trainable variable.
v2 = tf.Variable(20.0, name="v2", trainable=False)
tensor2 = v0 + v1
with tf.variable_scope("scope2"):
ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
self.assertEqual("scope2/scope1/v0/foo_avg", ema.average_name(v0))
self.assertEqual("scope2/scope1/v1/foo_avg", ema.average_name(v1))
self.assertEqual("scope2/scope1/add/foo_avg", ema.average_name(tensor2))
ema.apply([v0, v1, tensor2])
vars_to_restore = ema.variables_to_restore()
# vars_to_restore should contain the following:
# {scope2/scope1/v0/foo_avg : v0,
# scope2/scope1/v1/foo_avg : v1,
# scope2/scope1/add/foo_avg : add/foo_avg
# scope1/v2 : v2}
self.assertEqual(sorted(vars_to_restore.keys()),
sorted([ema.average_name(v0),
ema.average_name(v1),
ema.average_name(tensor2),
v2.op.name]))
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2),
ema.average(tensor2).op.name)
def testSubsetAverageVariablesNames(self):
with self.test_session():
v0 = tf.Variable(10.0, name="v0")

View File

@ -73,5 +73,13 @@ class SlotCreatorTest(tf.test.TestCase):
self.assertEqual(slot.dtype.base_dtype, tf.float32)
self.assertAllEqual(slot.eval(), [0.0, 0.0])
def testCreateSlotFromVariableRespectsScope(self):
# See discussion on #2740.
with self.test_session():
with tf.variable_scope("scope"):
v = tf.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
self.assertEqual(slot.op.name, "scope/scope/var/slot")
if __name__ == "__main__":
tf.test.main()

View File

@ -198,7 +198,7 @@ from tensorflow.core.example.feature_pb2 import *
from tensorflow.core.protobuf.saver_pb2 import *
# Utility op. Open Source. TODO(touts): move to nn?
from tensorflow.python.training.learning_rate_decay import exponential_decay
from tensorflow.python.training.learning_rate_decay import *
# Distributed computing support