TensorFlow: upstream changes to git (doc fixes).

Changes:

- Fix typos across several files contributed by Erik Erwitt,
  and Michael R. Berstein

- Fix bug in translate example (fr->en typo) by schuster

- Updates to some documentation (mcoram,shlens,vrv,joshl)

- Fix to Android camera demo app window size detection (andrewharp)

- Fix to support lookup table of high rank tensors (yleon)

- Fix invalid op names for parse_example (dga)

Base CL: 107531031
This commit is contained in:
Vijay Vasudevan 2015-11-10 15:23:01 -08:00
parent 9274f5aa47
commit c61c39614a
25 changed files with 209 additions and 272 deletions

View File

@ -109,9 +109,6 @@ class LookupTableFindOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
const Tensor& input = ctx->input(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input.shape()),
errors::InvalidArgument("Input must be a vector, not ",
input.shape().DebugString()));
const Tensor& default_value = ctx->input(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(default_value.shape()),
@ -119,8 +116,7 @@ class LookupTableFindOp : public OpKernel {
default_value.shape().DebugString()));
Tensor* out;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("output_values", input.shape(), &out));
OP_REQUIRES_OK(ctx, ctx->allocate_output("values", input.shape(), &out));
OP_REQUIRES_OK(ctx, table->Find(input, out, default_value));
}

View File

@ -284,29 +284,28 @@ handle: The handle to a queue.
size: The number of elements in the given queue.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("LookupTableFind")
.Input("table_handle: Ref(string)")
.Input("input_values: Tin")
.Input("keys: Tin")
.Input("default_value: Tout")
.Output("output_values: Tout")
.Output("values: Tout")
.Attr("Tin: type")
.Attr("Tout: type")
.Doc(R"doc(
Maps elements of a tensor into associated values given a lookup table.
Looks up keys in a table, outputs the corresponding values.
If an element of the input_values is not present in the table, the
specified default_value is used.
The tensor `keys` must of the same type as the keys of the table.
The output `values` is of the type of the table values.
The table needs to be initialized and the input and output types correspond
to the table key and value types.
The scalar `default_value` is the value output for keys not present in the
table. It must also be of the same type as the table values.
table_handle: A handle for a lookup table.
input_values: A vector of key values.
default_value: A scalar to return if the input is not found in the table.
output_values: A vector of values associated to the inputs.
table_handle: Handle to the table.
keys: Any shape. Keys to look up.
values: Same shape as `keys`. Values found in the table, or `default_values`
for missing keys.
)doc");
REGISTER_OP("LookupTableSize")
@ -315,8 +314,8 @@ REGISTER_OP("LookupTableSize")
.Doc(R"doc(
Computes the number of elements in the given table.
table_handle: The handle to a lookup table.
size: The number of elements in the given table.
table_handle: Handle to the table.
size: Scalar that contains number of elements in the table.
)doc");
REGISTER_OP("HashTable")
@ -326,18 +325,19 @@ REGISTER_OP("HashTable")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
.Doc(R"doc(
Creates and holds an immutable hash table.
Creates a non-initialized hash table.
The key and value types can be specified. After initialization, the table
becomes immutable.
This op creates a hash table, specifying the type of its keys and values.
Before using the table you will have to initialize it. After initialization the
table will be immutable.
table_handle: a handle of a the lookup table.
container: If non-empty, this hash table is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this hash table is shared under the given name across
table_handle: Handle to a table.
container: If non-empty, this table is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this table is shared under the given name across
multiple sessions.
key_dtype: the type of the table key.
value_dtype: the type of the table value.
key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
REGISTER_OP("InitializeTable")
@ -349,9 +349,9 @@ REGISTER_OP("InitializeTable")
.Doc(R"doc(
Table initializer that takes two tensors for keys and values respectively.
table_handle: a handle of the lookup table to be initialized.
keys: a vector of keys of type Tkey.
values: a vector of values of type Tval.
table_handle: Handle to a table which will be initialized.
keys: Keys of type Tkey.
values: Values of type Tval. Same shape as `keys`.
)doc");
} // namespace tensorflow

View File

@ -93,7 +93,7 @@ class Tensor {
/// \brief Slice this tensor along the 1st dimension.
/// I.e., the returned tensor satisifies
/// I.e., the returned tensor satisfies
/// returned[i, ...] == this[dim0_start + i, ...].
/// The returned tensor shares the underlying tensor buffer with this
/// tensor.

View File

@ -25,7 +25,7 @@ bool EventsWriter::Init() {
if (FileHasDisappeared()) {
// Warn user of data loss and let .reset() below do basic cleanup.
if (num_outstanding_events_ > 0) {
LOG(WARNING) << "Re-intialization, attempting to open a new file, "
LOG(WARNING) << "Re-initialization, attempting to open a new file, "
<< num_outstanding_events_ << " events will be lost.";
}
} else {

View File

@ -6,41 +6,46 @@ This folder contains a simple camera-based demo application utilizing Tensorflow
This demo uses a Google Inception model to classify camera frames in real-time,
displaying the top results in an overlay on the camera image. See
assets/imagenet_comp_graph_label_strings.txt for the possible classificiations.
[`assets/imagenet_comp_graph_label_strings.txt`](assets/imagenet_comp_graph_label_strings.txt)
for the possible classifications.
## To build/install/run
As a pre-requisite, Bazel, the Android NDK, and the Android SDK must all be
As a prerequisite, Bazel, the Android NDK, and the Android SDK must all be
installed on your system. The Android build tools may be obtained from:
https://developer.android.com/tools/revisions/build-tools.html
The Android entries in [<workspace_root>/WORKSPACE](../../WORKSPACE) must be
The Android entries in [`<workspace_root>/WORKSPACE`](../../WORKSPACE) must be
uncommented with the paths filled in appropriately depending on where you
installed the NDK and SDK. Otherwise an error such as:
"The external label '//external:android/sdk' is not bound to anything" will
be reported.
To build the APK, run this from your workspace root:
```bash
$ bazel build //tensorflow/examples/android:tensorflow_demo -c opt --copt=-mfpu=neon
```
bazel build //tensorflow/examples/android:tensorflow_demo -c opt --copt=-mfpu=neon
```
Note that "-c opt" is currently required; if not set, an assert (for an
Note that `-c opt` is currently required; if not set, an assert (for an
otherwise non-problematic issue) in Eigen will halt the application during
execution. This issue will be corrected in an upcoming release.
If adb debugging is enabled on your Android 5.0 or later device, you may then
use the following command from your workspace root to install the APK once
built:
'''
adb install -r -g bazel-bin/tensorflow/examples/android/tensorflow_demo_incremental.apk
'''
```bash
$ adb install -r -g bazel-bin/tensorflow/examples/android/tensorflow_demo_incremental.apk
```
Alternatively, a streamlined means of building, installing and running in one
command is:
```
bazel mobile-install //tensorflow/examples/android:tensorflow_demo -c opt --start_app --copt=-mfpu=neon
```bash
$ bazel mobile-install //tensorflow/examples/android:tensorflow_demo -c opt --start_app --copt=-mfpu=neon
```
If camera permission errors are encountered (possible on Android Marshmallow or
above), then the adb install command above should be used instead, as it
automatically grants the required camera permissions with '-g'.
above), then the `adb install` command above should be used instead, as it
automatically grants the required camera permissions with `-g`.

View File

@ -63,6 +63,12 @@ import java.util.concurrent.TimeUnit;
public class CameraConnectionFragment extends Fragment {
private static final Logger LOGGER = new Logger();
/**
* The camera preview size will be chosen to be the smallest frame by pixel size capable of
* containing a DESIRED_SIZE x DESIRED_SIZE square.
*/
private static final int MINIMUM_PREVIEW_SIZE = 320;
private RecognitionScoreView scoreView;
/**
@ -227,8 +233,7 @@ public class CameraConnectionFragment extends Fragment {
// Collect the supported resolutions that are at least as big as the preview Surface
final List<Size> bigEnough = new ArrayList<>();
for (final Size option : choices) {
// TODO(andrewharp): Choose size intelligently.
if (option.getHeight() == 320 && option.getWidth() == 480) {
if (option.getHeight() >= MINIMUM_PREVIEW_SIZE && option.getWidth() >= MINIMUM_PREVIEW_SIZE) {
LOGGER.i("Adding size: " + option.getWidth() + "x" + option.getHeight());
bigEnough.add(option);
} else {

View File

@ -178,7 +178,7 @@ This tensor shares other&apos;s underlying storage. Returns `true` iff `other.sh
Slice this tensor along the 1st dimension.
I.e., the returned tensor satisifies returned[i, ...] == this[dim0_start + i, ...]. The returned tensor shares the underlying tensor buffer with this tensor.
I.e., the returned tensor satisfies returned[i, ...] == this[dim0_start + i, ...]. The returned tensor shares the underlying tensor buffer with this tensor.
NOTE: The returned tensor may not satisfies the same alignment requirement as this tensor depending on the shape. The caller must check the returned tensor&apos;s alignment before calling certain methods that have alignment requirement (e.g., ` flat() `, `tensor()`).

View File

@ -1,52 +0,0 @@
# Class `tensorflow::TensorBuffer` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--tensorbuffer-"></a>
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
* [`tensorflow::TensorBuffer::~TensorBuffer() override`](#tensorflow_TensorBuffer_TensorBuffer)
* [`virtual void* tensorflow::TensorBuffer::data() const =0`](#virtual_void_tensorflow_TensorBuffer_data)
* [`virtual size_t tensorflow::TensorBuffer::size() const =0`](#virtual_size_t_tensorflow_TensorBuffer_size)
* [`virtual TensorBuffer* tensorflow::TensorBuffer::root_buffer()=0`](#virtual_TensorBuffer_tensorflow_TensorBuffer_root_buffer)
* [`virtual void tensorflow::TensorBuffer::FillAllocationDescription(AllocationDescription *proto) const =0`](#virtual_void_tensorflow_TensorBuffer_FillAllocationDescription)
* [`T* tensorflow::TensorBuffer::base() const`](#T_tensorflow_TensorBuffer_base)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
#### `tensorflow::TensorBuffer::~TensorBuffer() override` <a class="md-anchor" id="tensorflow_TensorBuffer_TensorBuffer"></a>
#### `virtual void* tensorflow::TensorBuffer::data() const =0` <a class="md-anchor" id="virtual_void_tensorflow_TensorBuffer_data"></a>
#### `virtual size_t tensorflow::TensorBuffer::size() const =0` <a class="md-anchor" id="virtual_size_t_tensorflow_TensorBuffer_size"></a>
#### `virtual TensorBuffer* tensorflow::TensorBuffer::root_buffer()=0` <a class="md-anchor" id="virtual_TensorBuffer_tensorflow_TensorBuffer_root_buffer"></a>
#### `virtual void tensorflow::TensorBuffer::FillAllocationDescription(AllocationDescription *proto) const =0` <a class="md-anchor" id="virtual_void_tensorflow_TensorBuffer_FillAllocationDescription"></a>
#### `T* tensorflow::TensorBuffer::base() const` <a class="md-anchor" id="T_tensorflow_TensorBuffer_base"></a>

View File

@ -1,45 +0,0 @@
# Class `tensorflow::TensorShapeIter` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--tensorshapeiter-"></a>
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
* [`tensorflow::TensorShapeIter::TensorShapeIter(const TensorShape *shape, int d)`](#tensorflow_TensorShapeIter_TensorShapeIter)
* [`bool tensorflow::TensorShapeIter::operator==(const TensorShapeIter &rhs)`](#bool_tensorflow_TensorShapeIter_operator_)
* [`bool tensorflow::TensorShapeIter::operator!=(const TensorShapeIter &rhs)`](#bool_tensorflow_TensorShapeIter_operator_)
* [`void tensorflow::TensorShapeIter::operator++()`](#void_tensorflow_TensorShapeIter_operator_)
* [`TensorShapeDim tensorflow::TensorShapeIter::operator*()`](#TensorShapeDim_tensorflow_TensorShapeIter_operator_)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
#### `tensorflow::TensorShapeIter::TensorShapeIter(const TensorShape *shape, int d)` <a class="md-anchor" id="tensorflow_TensorShapeIter_TensorShapeIter"></a>
#### `bool tensorflow::TensorShapeIter::operator==(const TensorShapeIter &rhs)` <a class="md-anchor" id="bool_tensorflow_TensorShapeIter_operator_"></a>
#### `bool tensorflow::TensorShapeIter::operator!=(const TensorShapeIter &rhs)` <a class="md-anchor" id="bool_tensorflow_TensorShapeIter_operator_"></a>
#### `void tensorflow::TensorShapeIter::operator++()` <a class="md-anchor" id="void_tensorflow_TensorShapeIter_operator_"></a>
#### `TensorShapeDim tensorflow::TensorShapeIter::operator*()` <a class="md-anchor" id="TensorShapeDim_tensorflow_TensorShapeIter_operator_"></a>

View File

@ -27,17 +27,15 @@ write the graph to a file.
##Classes <a class="md-anchor" id="AUTOGENERATED-classes"></a>
* [tensorflow::Env](../../api_docs/cc/ClassEnv.md)
* [tensorflow::EnvWrapper](../../api_docs/cc/ClassEnvWrapper.md)
* [tensorflow::RandomAccessFile](../../api_docs/cc/ClassRandomAccessFile.md)
* [tensorflow::WritableFile](../../api_docs/cc/ClassWritableFile.md)
* [tensorflow::EnvWrapper](../../api_docs/cc/ClassEnvWrapper.md)
* [tensorflow::Session](../../api_docs/cc/ClassSession.md)
* [tensorflow::Status](../../api_docs/cc/ClassStatus.md)
* [tensorflow::Tensor](../../api_docs/cc/ClassTensor.md)
* [tensorflow::TensorBuffer](../../api_docs/cc/ClassTensorBuffer.md)
* [tensorflow::TensorShape](../../api_docs/cc/ClassTensorShape.md)
* [tensorflow::TensorShapeIter](../../api_docs/cc/ClassTensorShapeIter.md)
* [tensorflow::TensorShapeUtils](../../api_docs/cc/ClassTensorShapeUtils.md)
* [tensorflow::Thread](../../api_docs/cc/ClassThread.md)
* [tensorflow::WritableFile](../../api_docs/cc/ClassWritableFile.md)
##Structs <a class="md-anchor" id="AUTOGENERATED-structs"></a>
@ -50,17 +48,15 @@ write the graph to a file.
<div class='sections-order' style="display: none;">
<!--
<!-- ClassEnv.md -->
<!-- ClassEnvWrapper.md -->
<!-- ClassRandomAccessFile.md -->
<!-- ClassWritableFile.md -->
<!-- ClassEnvWrapper.md -->
<!-- ClassSession.md -->
<!-- ClassStatus.md -->
<!-- ClassTensor.md -->
<!-- ClassTensorBuffer.md -->
<!-- ClassTensorShape.md -->
<!-- ClassTensorShapeIter.md -->
<!-- ClassTensorShapeUtils.md -->
<!-- ClassThread.md -->
<!-- ClassWritableFile.md -->
<!-- StructSessionOptions.md -->
<!-- StructState.md -->
<!-- StructTensorShapeDim.md -->

View File

@ -345,7 +345,7 @@ print sess.run(norm)
print sess.run(norm)
```
Another common use of random values is the intialization of variables. Also see
Another common use of random values is the initialization of variables. Also see
the [Variables How To](../../how_tos/variables/index.md).
```python

View File

@ -854,10 +854,10 @@ for an extensive description of how reusing works. Here is a basic example:
```python
with tf.variable_scope("foo"):
v = get_variable("v", [1]) # v.name == "foo/v:0"
w = get_variable("w", [1]) # w.name == "foo/w:0"
v = tf.get_variable("v", [1]) # v.name == "foo/v:0"
w = tf.get_variable("w", [1]) # w.name == "foo/w:0"
with tf.variable_scope("foo", reuse=True)
v1 = get_variable("v") # The same as v above.
v1 = tf.get_variable("v") # The same as v above.
```
If initializer is `None` (the default), the default initializer passed in
@ -919,7 +919,7 @@ Basic example of sharing a variable:
```python
with tf.variable_scope("foo"):
v = get_variable("v", [1])
v = tf.get_variable("v", [1])
with tf.variable_scope("foo", reuse=True):
v1 = tf.get_variable("v", [1])
assert v1 == v
@ -929,7 +929,7 @@ Sharing a variable by capturing a scope and setting reuse:
```python
with tf.variable_scope("foo") as scope.
v = get_variable("v", [1])
v = tf.get_variable("v", [1])
scope.reuse_variables()
v1 = tf.get_variable("v", [1])
assert v1 == v
@ -940,7 +940,7 @@ getting an existing variable in a non-reusing scope.
```python
with tf.variable_scope("foo") as scope.
v = get_variable("v", [1])
v = tf.get_variable("v", [1])
v1 = tf.get_variable("v", [1])
# Raises ValueError("... v already exists ...").
```
@ -950,7 +950,7 @@ does not exist in reuse mode.
```python
with tf.variable_scope("foo", reuse=True):
v = get_variable("v", [1])
v = tf.get_variable("v", [1])
# Raises ValueError("... v does not exists ...").
```

View File

@ -1664,10 +1664,12 @@ Adds a `Summary` protocol buffer to the event file.
This method wraps the provided summary in an `Event` procotol buffer
and adds it to the event file.
You can pass the output of any summary op, as-is, to this function. You
can also pass a `Summary` procotol buffer that you manufacture with your
own data. This is commonly done to report evaluation results in event
files.
You can pass the result of evaluating any summary op, using
[`Session.run()`](client.md#Session.run] or
[`Tensor.eval()`](framework.md#Tensor.eval), to this
function. Alternatively, you can pass a `tf.Summary` protocol
buffer that you populate with your own data. The latter is
commonly done to report evaluation results in event files.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>

View File

@ -222,7 +222,7 @@ From the root of your source tree, run:
``` bash
$ ./configure
Do you wish to bulid TensorFlow with GPU support? [y/n] y
Do you wish to build TensorFlow with GPU support? [y/n] y
GPU support will be enabled for TensorFlow
Please specify the location where CUDA 7.0 toolkit is installed. Refer to

View File

@ -83,7 +83,9 @@ You're now all set to visualize this data using TensorBoard.
To run TensorBoard, use the command
python tensorflow/tensorboard/tensorboard.py --logdir=path/to/log-directory
```bash
python tensorflow/tensorboard/tensorboard.py --logdir=path/to/log-directory
```
where `logdir` points to the directory where the `SummaryWriter` serialized its
data. If this `logdir` directory contains subdirectories which contain
@ -91,9 +93,12 @@ serialized data from separate runs, then TensorBoard will visualize the data
from all of those runs. Once TensorBoard is running, navigate your web browser
to `localhost:6006` to view the TensorBoard.
If you have pip installed TensorBoard, you can use the simpler command
If you have pip installed TensorFlow, `tensorboard` is installed into
the system path, so you can use the simpler command
tensorboard --logdir=/path/to/log-directory
```bash
tensorboard --logdir=/path/to/log-directory
```
When looking at TensorBoard, you will see the navigation tabs in the top right
corner. Each tab represents a set of serialized data that can be visualized.

View File

@ -93,7 +93,7 @@ weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
# Create another variable with the same value as 'weights'.
w2 = tf.Variable(weights.initialized_value(), name="w2")
# Create another variable with twice the value of 'weights'
w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")
w_twice = tf.Variable(weights.initialized_value() * 2.0, name="w_twice")
```
### Custom Initialization <a class="md-anchor" id="AUTOGENERATED-custom-initialization"></a>

View File

@ -178,8 +178,9 @@ building operations for training a model.
> **EXERCISE:** The model architecture in `inference()` differs slightly from
the CIFAR-10 model specified in
[cuda-convnet](https://code.google.com/p/cuda-convnet/). In particular, the top
layers are locally connected and not fully connected. Try editing the
architecture to exactly replicate that fully connected model.
layers of Alex's original model are locally connected and not fully connected.
Try editing the architecture to exactly reproduce the locally connected
architecture in the top layer.
### Model Training <a class="md-anchor" id="model-training"></a>
@ -224,7 +225,7 @@ the script `cifar10_train.py`.
python cifar10_train.py
```
**NOTE:** The first time you run any target in the CIFAR-10 tutorial,
> **NOTE:** The first time you run any target in the CIFAR-10 tutorial,
the CIFAR-10 dataset is automatically downloaded. The data set is ~160MB
so you may want to grab a quick cup of coffee for your first run.
@ -256,8 +257,8 @@ obtained on a Tesla K40c. If you are running on a CPU, expect slower performance
> **EXERCISE:** When experimenting, it is sometimes annoying that the first
training step can take so long. Try decreasing the number of images initially
that initially fill up the queue. Search for `NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN`
training step can take so long. Try decreasing the number of images that
initially fill up the queue. Search for `NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN`
in `cifar10.py`.
`cifar10_train.py` periodically [saves](../../api_docs/python/state_ops.md#Saver)
@ -303,7 +304,7 @@ for this purpose.
## Evaluating a Model <a class="md-anchor" id="evaluating-a-model"></a>
Let us now evaluate how well the trained model performs on a hold-out data set.
the model is evaluated by the script `cifar10_eval.py`. It constructs the model
The model is evaluated by the script `cifar10_eval.py`. It constructs the model
with the `inference()` function and uses all 10,000 images in the evaluation set
of CIFAR-10. It calculates the *precision at 1:* how often the top prediction
matches the true label of the image.
@ -415,9 +416,8 @@ See how-to on [Sharing Variables](../../how_tos/variable_scope/index.md).
### Launching and Training the Model on Multiple GPU cards <a class="md-anchor" id="AUTOGENERATED-launching-and-training-the-model-on-multiple-gpu-cards"></a>
If you have several GPU cards installed on your machine you can use them to
train the model faster with the `cifar10_multi_gpu_train.py` script. It is a
variation of the training script that parallelizes the model across multiple GPU
cards.
train the model faster with the `cifar10_multi_gpu_train.py` script. This
version of the training script parallelizes the model across multiple GPU cards.
```shell
python cifar10_multi_gpu_train.py --num_gpus=2
@ -451,7 +451,8 @@ completed the CIFAR-10 tutorial.
If you are now interested in developing and training your own image
classification system, we recommend forking this tutorial and replacing
components to build address your image classification problem.
components to address your image classification problem.
> **EXERCISE:** Download the
[Street View House Numbers (SVHN)](http://ufldl.stanford.edu/housenumbers/) data set.

View File

@ -310,7 +310,7 @@ cross_entropy = -tf.reduce_sum(y_*tf.log(y))
```
First, `tf.log` computes the logarithm of each element of `y`. Next, we multiply
each element of `y_` with the corresponding element of `tf.log(y_)`. Finally,
each element of `y_` with the corresponding element of `tf.log(y)`. Finally,
`tf.reduce_sum` adds all the elements of the tensor. (Note that this isn't
just the cross-entropy of the truth with a single prediction, but the sum of the
cross-entropies for all 100 images we looked at. How well we are doing on 100

View File

@ -311,7 +311,7 @@ translator, end-to-end. Run it and see how the model performs for yourself.
While it has reasonable quality, the default parameters will not give you
the best translation model. Here are a few things you can improve.
First of all, we use a very promitive tokenizer, the `basic_tokenizer` function
First of all, we use a very primitive tokenizer, the `basic_tokenizer` function
in `data_utils`. A better tokenizer can be found on the
[WMT'15 Website](http://www.statmt.org/wmt15/translation-task.html).
Using that tokenizer, and a larger vocabulary, should improve your translations.

View File

@ -254,7 +254,7 @@ def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size):
fr_train_ids_path = train_path + (".ids%d.fr" % fr_vocabulary_size)
en_train_ids_path = train_path + (".ids%d.en" % en_vocabulary_size)
data_to_token_ids(train_path + ".fr", fr_train_ids_path, fr_vocab_path)
data_to_token_ids(train_path + ".en", fr_train_ids_path, fr_vocab_path)
data_to_token_ids(train_path + ".en", en_train_ids_path, en_vocab_path)
# Create token ids for the development data.
fr_dev_ids_path = dev_path + (".ids%d.fr" % fr_vocabulary_size)

View File

@ -26,6 +26,25 @@ class HashTableOpTest(tf.test.TestCase):
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
def testHashTableFindHighRank(self):
with self.test_session():
shared_name = ''
default_val = -1
table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
# Initialize with keys and values tensors.
keys = tf.constant(['brain', 'salad', 'surgery'])
values = tf.constant([0, 1, 2], tf.int64)
init = table.initialize_from(keys, values)
init.run()
self.assertAllEqual(3, table.size().eval())
input_string = tf.constant([['brain', 'salad'], ['tank', 'tarkus']])
output = table.lookup(input_string)
result = output.eval()
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
with self.test_session():
shared_name = ''

View File

@ -85,13 +85,14 @@ class ParseExampleTest(tf.test.TestCase):
sess.run(result)
def testEmptySerializedWithAllDefaults(self):
dense_keys = ["a", "b", "c"]
cname = "c:has_a_tricky_name"
dense_keys = ["a", "b", cname]
dense_shapes = [(1, 3), (3, 3), (2,)]
dense_types = [tf.int64, tf.string, tf.float32]
dense_defaults = {
"a": [0, 42, 0],
"b": np.random.rand(3, 3).astype(np.str),
"c": np.random.rand(2).astype(np.float32),
cname: np.random.rand(2).astype(np.float32),
}
expected_st_a = ( # indices, values, shape
@ -103,7 +104,7 @@ class ParseExampleTest(tf.test.TestCase):
"st_a": expected_st_a,
"a": np.array(2 * [[dense_defaults["a"]]]),
"b": np.array(2 * [dense_defaults["b"]]),
"c": np.array(2 * [dense_defaults["c"]]),
cname: np.array(2 * [dense_defaults[cname]]),
}
self._test(
@ -210,14 +211,15 @@ class ParseExampleTest(tf.test.TestCase):
}, expected_output)
def testSerializedContainingDense(self):
bname = "b*has+a:tricky_name"
original = [
example(features=features({
"a": float_feature([1, 1]),
"b": bytes_feature(["b0_str"]),
bname: bytes_feature(["b0_str"]),
})),
example(features=features({
"a": float_feature([-1, -1]),
"b": bytes_feature(["b1"]),
bname: bytes_feature(["b1"]),
}))
]
@ -227,14 +229,14 @@ class ParseExampleTest(tf.test.TestCase):
expected_output = {
"a": np.array([[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
"b": np.array(["b0_str", "b1"], dtype=np.str).reshape(2, 1, 1, 1, 1),
bname: np.array(["b0_str", "b1"], dtype=np.str).reshape(2, 1, 1, 1, 1),
}
# No defaults, values required
self._test(
{
"serialized": tf.convert_to_tensor(serialized),
"dense_keys": ["a", "b"],
"dense_keys": ["a", bname],
"dense_types": [tf.float32, tf.string],
"dense_shapes": dense_shapes,
}, expected_output)

View File

@ -58,7 +58,7 @@ print sess.run(norm)
print sess.run(norm)
```
Another common use of random values is the intialization of variables. Also see
Another common use of random values is the initialization of variables. Also see
the [Variables How To](../../how_tos/variables/index.md).
```python

View File

@ -412,10 +412,9 @@ class LookupTableBase(object):
"""Construct a table object from a table reference.
Args:
key_dtype: The key data type of the table.
value_dtype: The kvalue data type of the table.
default_value: The scalar tensor to be used when a key is not present in
the table.
key_dtype: The table key type.
value_dtype: The table value type.
default_value: The value to use if a key is missing in the table.
table_ref: The table reference, i.e. the output of the lookup table ops.
"""
self._key_dtype = types.as_dtype(key_dtype)
@ -434,12 +433,12 @@ class LookupTableBase(object):
@property
def key_dtype(self):
"""The key dtype supported by the table."""
"""The table key dtype."""
return self._key_dtype
@property
def value_dtype(self):
"""The value dtype supported by the table."""
"""The table value dtype."""
return self._value_dtype
@property
@ -466,20 +465,19 @@ class LookupTableBase(object):
return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
def lookup(self, keys, name=None):
"""Returns the values for the given 'keys' tensor.
"""Looks up `keys` in a table, outputs the corresponding values.
If an element on the key tensor is not found in the table, the default_value
is used.
The `default_value` is use for keys not present in the table.
Args:
keys: The tensor for the keys.
keys: Keys to look up.
name: Optional name for the op.
Returns:
The operation that looks up the keys.
Raises:
TypeError: when 'keys' or 'default_value' doesn't match the table data
TypeError: when `keys` or `default_value` doesn't match the table data
types.
"""
if name is None:
@ -493,7 +491,7 @@ class LookupTableBase(object):
self._table_ref, keys, self._default_value, name=name)
def initialize_from(self, keys, values, name=None):
"""Initialize the lookup table with the provided keys and values tensors.
"""Initialize the table with the provided keys and values tensors.
Construct an initializer object from keys and value tensors.
@ -503,10 +501,10 @@ class LookupTableBase(object):
name: Optional name for the op.
Returns:
The operation that initializes a lookup table.
The operation that initializes the table.
Raises:
TypeError: when the 'keys' and 'values' data type do not match the table
TypeError: when the keys and values data types do not match the table
key and value data types.
"""
if name is None:
@ -545,22 +543,23 @@ class HashTable(LookupTableBase):
def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
name="hash_table"):
"""Create a generic hash table.
"""Creates a non-initialized hash table.
A table holds a key-value pairs. The key and value types are
described by key_dtype and value_dtype respectively.
This op creates a hash table, specifying the type of its keys and values.
Before using the table you will have to initialize it. After initialization
the table will be immutable.
Args:
key_dtype: The key data type of the table.
value_dtype: The kvalue data type of the table.
default_value: The scalar tensor to be used when a key is not present in
the table.
key_dtype: Type of the table keys.
value_dtype: Type of the table values.
default_value: The scalar tensor to be used when a key is missing in the
table.
shared_name: Optional. If non-empty, this table will be shared under
the given name across multiple sessions.
name: Optional name for the hash table op.
Returns:
A table object that can be used to lookup data.
A `HashTable` object.
"""
table_ref = gen_data_flow_ops._hash_table(
shared_name=shared_name, key_dtype=key_dtype,

View File

@ -1,4 +1,5 @@
"""Parsing Ops."""
import re
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -166,9 +167,10 @@ def parse_example(serialized,
```
Args:
serialized: A list of strings, a batch of binary serialized `Example`
protos.
names: A list of strings, the names of the serialized protos.
serialized: A vector (1-D Tensor) of strings, a batch of binary
serialized `Example` protos.
names: A vector (1-D Tensor) of strings (optional), the names of
the serialized protos.
sparse_keys: A list of string keys in the examples' features.
The results for these keys will be returned as `SparseTensor` objects.
sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
@ -192,67 +194,69 @@ def parse_example(serialized,
ValueError: If sparse and dense key sets intersect, or input lengths do not
match up.
"""
names = [] if names is None else names
dense_defaults = {} if dense_defaults is None else dense_defaults
sparse_keys = [] if sparse_keys is None else sparse_keys
sparse_types = [] if sparse_types is None else sparse_types
dense_keys = [] if dense_keys is None else dense_keys
dense_types = [] if dense_types is None else dense_types
dense_shapes = [
[]] * len(dense_keys) if dense_shapes is None else dense_shapes
with ops.op_scope([serialized, names], name, "parse_example"):
names = [] if names is None else names
dense_defaults = {} if dense_defaults is None else dense_defaults
sparse_keys = [] if sparse_keys is None else sparse_keys
sparse_types = [] if sparse_types is None else sparse_types
dense_keys = [] if dense_keys is None else dense_keys
dense_types = [] if dense_types is None else dense_types
dense_shapes = [
[]] * len(dense_keys) if dense_shapes is None else dense_shapes
num_dense = len(dense_keys)
num_sparse = len(sparse_keys)
num_dense = len(dense_keys)
num_sparse = len(sparse_keys)
if len(dense_shapes) != num_dense:
raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d"
% (len(dense_shapes), num_dense))
if len(dense_types) != num_dense:
raise ValueError("len(dense_types) != len(num_dense): %d vs. %d"
% (len(dense_types), num_dense))
if len(sparse_types) != num_sparse:
raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d"
% (len(sparse_types), num_sparse))
if num_dense + num_sparse == 0:
raise ValueError("Must provide at least one sparse key or dense key")
if not set(dense_keys).isdisjoint(set(sparse_keys)):
raise ValueError(
"Dense and sparse keys must not intersect; intersection: %s" %
set(dense_keys).intersection(set(sparse_keys)))
if len(dense_shapes) != num_dense:
raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d"
% (len(dense_shapes), num_dense))
if len(dense_types) != num_dense:
raise ValueError("len(dense_types) != len(num_dense): %d vs. %d"
% (len(dense_types), num_dense))
if len(sparse_types) != num_sparse:
raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d"
% (len(sparse_types), num_sparse))
if num_dense + num_sparse == 0:
raise ValueError("Must provide at least one sparse key or dense key")
if not set(dense_keys).isdisjoint(set(sparse_keys)):
raise ValueError(
"Dense and sparse keys must not intersect; intersection: %s" %
set(dense_keys).intersection(set(sparse_keys)))
dense_defaults_vec = []
for i, key in enumerate(dense_keys):
default_value = dense_defaults.get(key)
if default_value is None:
default_value = constant_op.constant([], dtype=dense_types[i])
elif not isinstance(default_value, ops.Tensor):
default_value = ops.convert_to_tensor(
default_value, dtype=dense_types[i], name=key)
default_value = array_ops.reshape(default_value, dense_shapes[i])
dense_defaults_vec = []
for i, key in enumerate(dense_keys):
default_value = dense_defaults.get(key)
if default_value is None:
default_value = constant_op.constant([], dtype=dense_types[i])
elif not isinstance(default_value, ops.Tensor):
key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
default_value = ops.convert_to_tensor(
default_value, dtype=dense_types[i], name=key_name)
default_value = array_ops.reshape(default_value, dense_shapes[i])
dense_defaults_vec.append(default_value)
dense_defaults_vec.append(default_value)
dense_shapes = [tensor_util.MakeTensorShapeProto(shape)
if isinstance(shape, (list, tuple)) else shape
for shape in dense_shapes]
dense_shapes = [tensor_util.MakeTensorShapeProto(shape)
if isinstance(shape, (list, tuple)) else shape
for shape in dense_shapes]
outputs = gen_parsing_ops._parse_example(
serialized=serialized,
names=names,
dense_defaults=dense_defaults_vec,
sparse_keys=sparse_keys,
sparse_types=sparse_types,
dense_keys=dense_keys,
dense_shapes=dense_shapes,
name=name)
outputs = gen_parsing_ops._parse_example(
serialized=serialized,
names=names,
dense_defaults=dense_defaults_vec,
sparse_keys=sparse_keys,
sparse_types=sparse_types,
dense_keys=dense_keys,
dense_shapes=dense_shapes,
name=name)
(sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs
(sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs
sparse_tensors = [ops.SparseTensor(ix, val, shape) for (ix, val, shape)
in zip(sparse_indices, sparse_values, sparse_shapes)]
sparse_tensors = [ops.SparseTensor(ix, val, shape) for (ix, val, shape)
in zip(sparse_indices, sparse_values, sparse_shapes)]
return dict(
zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
return dict(
zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
def parse_single_example(serialized, # pylint: disable=invalid-name
@ -280,9 +284,9 @@ def parse_single_example(serialized, # pylint: disable=invalid-name
See also `parse_example`.
Args:
serialized: A scalar string, a single serialized Example.
serialized: A scalar string Tensor, a single serialized Example.
See parse_example documentation for more details.
names: (Optional) A scalar string, the associated name.
names: (Optional) A scalar string Tensor, the associated name.
See parse_example documentation for more details.
sparse_keys: See parse_example documentation for more details.
sparse_types: See parse_example documentation for more details.
@ -298,7 +302,7 @@ def parse_single_example(serialized, # pylint: disable=invalid-name
Raises:
ValueError: if "scalar" or "names" have known shapes, and are not scalars.
"""
with ops.op_scope([serialized], name, "parse_single_example"):
with ops.op_scope([serialized, names], name, "parse_single_example"):
serialized = ops.convert_to_tensor(serialized)
serialized_shape = serialized.get_shape()
if serialized_shape.ndims is not None: