Branch 168917534 (#13077)

* Use HLO name, rather than pointer address, for profile counter name.

This removes a source of nondeterminism in IR generation.

PiperOrigin-RevId: 168779489

* Eager gradient tape doesn't keep tensors alive.

PiperOrigin-RevId: 168782341

* Add missing back-quote

PiperOrigin-RevId: 168785422

* Add in a comment that I forgot to add to a previous commit; NFC.

PiperOrigin-RevId: 168786760

* Update ops-related pbtxt files.

PiperOrigin-RevId: 168787665

* Go: Update generated wrapper functions for TensorFlow ops.

PiperOrigin-RevId: 168788051

* Fix typo "comptuation" (computation)

PiperOrigin-RevId: 168799777

* Fix a bug in export GTFlow model to shared format with sparse float split

PiperOrigin-RevId: 168802503

* Add signature def utility functions for inspection of input and output types and shapes.

PiperOrigin-RevId: 168820997

* Apply const qualifiers whenever appropriate.

PiperOrigin-RevId: 168824461

* TFE: Clearer error message when enable_eager_execution is called more than once

PiperOrigin-RevId: 168834147

* [tf.contrib.data] Add colocation constraints between Iterator and Datasets.

This restores the colocation behavior that was present when Dataset
objects were passed as DT_RESOURCE tensors, and avoids the (currently
not supported) case where TensorFlow may attempt to split the dataset
pipeline across devices.

PiperOrigin-RevId: 168841061

* Optimize C++ kernels for the matrix_band_part op, which is used in various ops operating on triangular or banded matrices:
 * Add benchmark for matrix_band_part.
 * Implement simple optimized CUDA kernel instead of calling Eigen generator.
 * Parallelize CPU kernel for matrix_band_part.
 * Support on-the-fly transposition in the underlying functors (to be used for future QR op in followup).

Benchmarks:

First column is of the form {device}_{shape}_{num_lower,num_upper}

Test case                        Before       After    Speedup
cpu_(10,16,16)_(-1,-1)          5.6505e-05  6.2108e-05  -9.92%
cpu_(10,16,16)_(-1,0)           0.00010848  0.00010908  -0.55%
cpu_(10,16,16)_(0,-1)            0.0001055  0.00011396  -8.02%
cpu_(10,16,16)_(2,2)              0.000108  0.00011706  -8.39%
cpu_(10,101,101)_(-1,-1)        0.00013697  6.0558e-05 +55.79%
cpu_(10,101,101)_(-1,0)         0.00054002  0.00017703 +67.22%
cpu_(10,101,101)_(0,-1)         0.00051188  0.00017607 +65.60%
cpu_(10,101,101)_(2,2)          0.00050449  0.00016904 +66.49%
cpu_(10,256,256)_(-1,-1)        0.00032043  5.6028e-05 +82.51%
cpu_(10,256,256)_(-1,0)           0.001335   0.0004015 +69.93%
cpu_(10,256,256)_(0,-1)          0.0013521  0.00038862 +71.26%
cpu_(10,256,256)_(2,2)            0.001269  0.00039959 +68.51%
cpu_(10,1000,1000)_(-1,-1)       0.0090729  6.3419e-05 +99.30%
cpu_(10,1000,1000)_(-1,0)          0.01712   0.0047594 +72.20%
cpu_(10,1000,1000)_(0,-1)         0.016647   0.0046474 +72.08%
cpu_(10,1000,1000)_(2,2)          0.012737   0.0041161 +67.68%
cpu_(10,1024,1024)_(-1,-1)       0.0093709  5.8889e-05 +99.37%
cpu_(10,1024,1024)_(-1,0)         0.017075   0.0051999 +69.55%
cpu_(10,1024,1024)_(0,-1)         0.016867    0.004617 +72.63%
cpu_(10,1024,1024)_(2,2)          0.013191    0.003759 +71.50%
cpu_(10,2048,2048)_(-1,-1)        0.028427  6.2466e-05 +99.78%
cpu_(10,2048,2048)_(-1,0)         0.048134    0.017642 +63.35%
cpu_(10,2048,2048)_(0,-1)         0.048773    0.017558 +64.00%
cpu_(10,2048,2048)_(2,2)          0.036153    0.015452 +57.26%
cpu_(10,10,4,4)_(-1,-1)         5.8055e-05  5.8055e-05  +0.00%
cpu_(10,10,4,4)_(-1,0)          0.00015557   0.0001564  -0.54%
cpu_(10,10,4,4)_(0,-1)          0.00015855  0.00015199  +4.14%
cpu_(10,10,4,4)_(2,2)           0.00016379  0.00018096 -10.48%
cpu_(10,10,10,10)_(-1,-1)       6.0558e-05  6.0558e-05  +0.00%
cpu_(10,10,10,10)_(-1,0)          0.000368  0.00038695  -5.15%
cpu_(10,10,10,10)_(0,-1)        0.00036263  0.00038612  -6.48%
cpu_(10,10,10,10)_(2,2)         0.00038648  0.00042963 -11.17%
cpu_(10,10,16,16)_(-1,-1)       6.9022e-05  5.7578e-05 +16.58%
cpu_(10,10,16,16)_(-1,0)         0.0005815   0.0001874 +67.77%
cpu_(10,10,16,16)_(0,-1)        0.00059354   0.0001924 +67.58%
cpu_(10,10,16,16)_(2,2)         0.00062239  0.00019097 +69.32%
cpu_(10,10,101,101)_(-1,-1)     0.00014806  6.2823e-05 +57.57%
cpu_(10,10,101,101)_(-1,0)       0.0039785  0.00078249 +80.33%
cpu_(10,10,101,101)_(0,-1)       0.0040585  0.00076556 +81.14%
cpu_(10,10,101,101)_(2,2)        0.0039514  0.00077307 +80.44%
cpu_(10,10,256,256)_(-1,-1)      0.0026824  6.0558e-05 +97.74%
cpu_(10,10,256,256)_(-1,0)        0.017269   0.0031619 +81.69%
cpu_(10,10,256,256)_(0,-1)        0.020287   0.0030774 +84.83%
cpu_(10,10,256,256)_(2,2)         0.011919   0.0026599 +77.68%
cpu_(10,10,1000,1000)_(-1,-1)     0.065783  5.6982e-05 +99.91%
cpu_(10,10,1000,1000)_(-1,0)        0.1361    0.054533 +59.93%
cpu_(10,10,1000,1000)_(0,-1)        0.1397    0.053405 +61.77%
cpu_(10,10,1000,1000)_(2,2)        0.10173    0.048561 +52.26%
cpu_(10,10,1024,1024)_(-1,-1)     0.066231  7.5579e-05 +99.89%
cpu_(10,10,1024,1024)_(-1,0)       0.13615    0.059931 +55.98%
cpu_(10,10,1024,1024)_(0,-1)       0.13745    0.064931 +52.76%
cpu_(10,10,1024,1024)_(2,2)        0.10493    0.054258 +48.29%
cpu_(10,10,2048,2048)_(-1,-1)      0.23487  6.6042e-05 +99.97%
cpu_(10,10,2048,2048)_(-1,0)       0.41014     0.24283 +40.79%
cpu_(10,10,2048,2048)_(0,-1)       0.43621     0.26393 +39.49%
cpu_(10,10,2048,2048)_(2,2)        0.29919     0.22302 +25.46%

gpu_(10,16,16)_(-1,-1)          0.00010753  0.00010753  +0.00%
gpu_(10,16,16)_(-1,0)           0.00011253  0.00012445 -10.59%
gpu_(10,16,16)_(0,-1)           0.00012493  0.00013399  -7.25%
gpu_(10,16,16)_(2,2)              0.000108  0.00011754  -8.83%
gpu_(10,101,101)_(-1,-1)        0.00011849  8.7976e-05 +25.75%
gpu_(10,101,101)_(-1,0)         0.00012743  0.00012243  +3.93%
gpu_(10,101,101)_(0,-1)         0.00012958  0.00012362  +4.60%
gpu_(10,101,101)_(2,2)          0.00011504  0.00011504  +0.00%
gpu_(10,256,256)_(-1,-1)        0.00013447  9.7513e-05 +27.48%
gpu_(10,256,256)_(-1,0)         0.00018752  0.00014746 +21.36%
gpu_(10,256,256)_(0,-1)         0.00017798  0.00016904  +5.02%
gpu_(10,256,256)_(2,2)           0.0001514  0.00013697  +9.53%
gpu_(10,1000,1000)_(-1,-1)       0.0005095  9.8586e-05 +80.65%
gpu_(10,1000,1000)_(-1,0)       0.00088501  0.00056589 +36.06%
gpu_(10,1000,1000)_(0,-1)       0.00090456  0.00055242 +38.93%
gpu_(10,1000,1000)_(2,2)        0.00080955  0.00049639 +38.68%
gpu_(10,1024,1024)_(-1,-1)      0.00050902  9.7036e-05 +80.94%
gpu_(10,1024,1024)_(-1,0)       0.00098789  0.00058246 +41.04%
gpu_(10,1024,1024)_(0,-1)            0.001  0.00059545 +40.46%
gpu_(10,1024,1024)_(2,2)        0.00082254  0.00049961 +39.26%
gpu_(10,2048,2048)_(-1,-1)        0.001495  9.8944e-05 +93.38%
gpu_(10,2048,2048)_(-1,0)         0.003535   0.0017736 +49.83%
gpu_(10,2048,2048)_(0,-1)        0.0034965   0.0017921 +48.75%
gpu_(10,2048,2048)_(2,2)         0.0027704   0.0015399 +44.41%
gpu_(10,10,4,4)_(-1,-1)         0.00011086  9.1076e-05 +17.85%
gpu_(10,10,4,4)_(-1,0)           0.0001235  0.00013411  -8.59%
gpu_(10,10,4,4)_(0,-1)          0.00011849   0.0001204  -1.61%
gpu_(10,10,4,4)_(2,2)           0.00010896  0.00013256 -21.66%
gpu_(10,10,10,10)_(-1,-1)       0.00010657  9.5844e-05 +10.07%
gpu_(10,10,10,10)_(-1,0)        0.00011754  0.00013602 -15.72%
gpu_(10,10,10,10)_(0,-1)        0.00011909  0.00012004  -0.80%
gpu_(10,10,10,10)_(2,2)         0.00013196  0.00011349 +14.00%
gpu_(10,10,16,16)_(-1,-1)       0.00012898  0.00010705 +17.01%
gpu_(10,10,16,16)_(-1,0)        0.00014353  0.00012338 +14.04%
gpu_(10,10,16,16)_(0,-1)        0.00011599  0.00012493  -7.71%
gpu_(10,10,16,16)_(2,2)         0.00011539  0.00011349  +1.65%
gpu_(10,10,101,101)_(-1,-1)     0.00014699  0.00010252 +30.25%
gpu_(10,10,101,101)_(-1,0)       0.0002141  0.00015497 +27.62%
gpu_(10,10,101,101)_(0,-1)       0.0002017  0.00015843 +21.45%
gpu_(10,10,101,101)_(2,2)       0.00018394  0.00015402 +16.27%
gpu_(10,10,256,256)_(-1,-1)     0.00032747  9.0003e-05 +72.52%
gpu_(10,10,256,256)_(-1,0)      0.00074494  0.00040746 +45.30%
gpu_(10,10,256,256)_(0,-1)      0.00072503  0.00042391 +41.53%
gpu_(10,10,256,256)_(2,2)       0.00061846  0.00038004 +38.55%
gpu_(10,10,1000,1000)_(-1,-1)    0.0032645  0.00010896 +96.66%
gpu_(10,10,1000,1000)_(-1,0)      0.007543   0.0038971 +48.34%
gpu_(10,10,1000,1000)_(0,-1)      0.006058   0.0039405 +34.95%
gpu_(10,10,1000,1000)_(2,2)       0.005198    0.003448 +33.67%
gpu_(10,10,1024,1024)_(-1,-1)    0.0034155  9.1434e-05 +97.32%
gpu_(10,10,1024,1024)_(-1,0)      0.007099    0.004158 +41.43%
gpu_(10,10,1024,1024)_(0,-1)      0.006843    0.003849 +43.75%
gpu_(10,10,1024,1024)_(2,2)       0.005506   0.0031376 +43.02%
gpu_(10,10,2048,2048)_(-1,-1)     0.013119  0.00010097 +99.23%
gpu_(10,10,2048,2048)_(-1,0)      0.028533    0.015175 +46.81%
gpu_(10,10,2048,2048)_(0,-1)      0.028458    0.014926 +47.55%
gpu_(10,10,2048,2048)_(2,2)       0.022175    0.011797 +46.80%

PiperOrigin-RevId: 168849471

* * dataset_ops.read_batch_features() now discards keys for keyed Dataset.
* dataset_ops.read_batch_features() ignores unnecessary repeat() when num_repeat == 1.

PiperOrigin-RevId: 168855155

* Migrate TFGAN eval to opensource.

PiperOrigin-RevId: 168855880

* [XLA] Remove superfluous locking from xla::ComputationBuilder.

The class is thread compatible, not thread-safe. It is illegal to call non-const methods of the class concurrently. So the mutex is pointless.

Also mark a couple of accessors const.

PiperOrigin-RevId: 168857132

* Add ConvertGraphDefToXla to convert from GraphDef to xla::Computation.

The main logic is simply refactored from tfcompile, with some minor cleanups
along the way.

PiperOrigin-RevId: 168857174

* Bugfix to tf.contrib.seq2seq beam_search_ops: GPU edge case of seq_len == 0.

PiperOrigin-RevId: 168862288

* [tf.contrib.data] Add `batch_and_drop_remainder` transformation.

This transformation, which is designed for use with `Dataset.apply()`,
acts like the default of behavior of `tf.train.batch()`, which will
truncate a finite input source if its number of elements is not an
exact multiple of the batch size. A benefit of using this
transformation is that it gives a statically known shape to the output
elements, because they are all exactly `batch_size` in the 0th
dimension.

PiperOrigin-RevId: 168863148

* Minor renaming from tfcompile.Config to tf2xla.Config in comments.

PiperOrigin-RevId: 168863860

* Certain ops don't need eager gradients to keep their inputs / outputs alive.

PiperOrigin-RevId: 168864350

* [XLA] Add S64 while loop test.

PiperOrigin-RevId: 168865653

* tfdbg: fix a bug in list_inputs and list_outputs

wherein a tensor name like "x:1" fails to be processed because it were not converted to the node name ("x" in this example) first.

Also simplify analyzer_cli_test.py a little through a new helper function.

PiperOrigin-RevId: 168867948

* Adds multi_label_head in tf.contrib.estimator

PiperOrigin-RevId: 168873313

* Script that generates __init__.py files based on tf_api_names annotations.

PiperOrigin-RevId: 168878737

* Fixing the build command.

PiperOrigin-RevId: 168881605

* Make sure all checked threads are joined before they are terminated.

PiperOrigin-RevId: 168884294

* Output metrics in train mode for multihead.

This is to be consistent with other heads who output the metric tensors in train mode. Outputting the metric tensors allow us for example to plot the metrics on the training set (and compare them to the metircs on the eval set).

PiperOrigin-RevId: 168884726

* Automated g4 rollback of changelist 168458634

PiperOrigin-RevId: 168887778

* Adds DNNEstimator to tf.contrib.estimator.

PiperOrigin-RevId: 168887825

* [tf.contrib.data] Expose `tf.contrib.data.batch_and_drop_remainder()`.

PiperOrigin-RevId: 168888592

* disabling timeout test in opensource build

PiperOrigin-RevId: 168890483

* Add ops that perform color transforms (including changing value, saturation and hue) in YIQ space.

PiperOrigin-RevId: 168897736

* Update the minimum requirement of espsilon for batch norm.

PiperOrigin-RevId: 168897907

* Adding support for capture-by-value.

PiperOrigin-RevId: 168903482

* disabling failing tsan test

PiperOrigin-RevId: 168903876

* disable asan for test timeout

PiperOrigin-RevId: 168903999

* Internal change.

PiperOrigin-RevId: 168910187

* Fix broken test: tensorflow/contrib/eager/python:datasets_test

PiperOrigin-RevId: 168914742

* [XLA:CPU] Implement map fusion.

PiperOrigin-RevId: 168915358

* Merge changes from github.
END_PUBLIC

I also integrated #13073 by hand to make TAP happy.

---
Commit 92362d0f0 authored by Skye Wanderman-Milne<skyewm@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add WhileContext class and add plumbing for creating them.

This change introduces WhileContext, which stores information about a
while loop and will be used in future changes to generate while loop
gradient graphs. Exit nodes in a while loop now have a pointer to
their associated WhileContext. This will be used to retrieve the
context for a given loop.

This change adds an optional parameter to BuildWhileLoop() to create a
WhileContext for the while loop (currently this is always true, but
gradients will generate while loops without associated contexts). This
change also adds a as-yet-unused option to BuildWhileLoop() to return
the predicate output.

PiperOrigin-RevId: 168562303

---
Commit a4f6e7c1a authored by RJ Ryan<rjryan@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add mel-scale conversion matrix support to tf.contrib.signal.

PiperOrigin-RevId: 168560255

---
Commit b00b6d23c authored by Henry Tan<henrytan@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fix a segmentation fault caused by invalid log directory in InternalFlush().

PiperOrigin-RevId: 168557063

---
Commit 2bc7a155a authored by Yong Tang<yong.tang.github@outlook.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Add uint16 support for tf.decode_raw (#12719)

* Add uint16 support for tf.decode_raw

This fix tries to address the request raised in 10124 where
uint16 support for tf.decode_raw is needed. tf.decode_raw
already support half, float32, float64, int8, int16, int32, int64,
uint8. And uint16 was not supported.

This fix adds uint16 support for tf.decode_raw.

This fix fixes 10124.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix test failure caused by uint16 support of decode_raw and add unit tests.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

---
Commit 009285c09 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Remove benchmark for TensorShapeOld.

PiperOrigin-RevId: 168551108

---
Commit dc1eda8a6 authored by Peter Hawkins<phawkins@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Fix CHECK-failure crash if a non-tuple was passed to GetTupleElement.

PiperOrigin-RevId: 168550703

---
Commit 010922ed9 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Go: Update generated wrapper functions for TensorFlow ops.

PiperOrigin-RevId: 168549989

---
Commit c8a6131e9 authored by Mark Daoust<markdaoust@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
make `tf.sets` examples executable

Fixes #12969

PiperOrigin-RevId: 168549712

---
Commit bece65c6f authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Use a map instead of a vector of Children() in the BeamEntry.

The assumption is that since the entries are sparse (they are all populated, but most are never Active()), using the map will save memory and make iterating over the Children() more efficient.

PiperOrigin-RevId: 168548814

---
Commit 0d5ab82ce authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Update ops-related pbtxt files.

PiperOrigin-RevId: 168548642

---
Commit 3331c574b authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Implementing gradients for tf.image.resize_bicubic.

PiperOrigin-RevId: 168547412

---
Commit 4982ef0fa authored by Martin Wicke<wicke@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add the ability to warn only once if deprecated functionality is used, and make that the default.

PiperOrigin-RevId: 168545655

---
Commit 99423416a authored by Peter Hawkins<phawkins@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Make shape inference error messages for the While HLO more readable. Build the error lazily.

PiperOrigin-RevId: 168531083

---
Commit d10374e45 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Discard some unneccessary logging commands.

PiperOrigin-RevId: 168500721

---
Commit 83cbabb85 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fix wrong format of logging message.

PiperOrigin-RevId: 168497373

---
Commit eec4f1b3a authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Go: Update generated wrapper functions for TensorFlow ops.

PiperOrigin-RevId: 168494944

---
Commit 69301f352 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Update ops-related pbtxt files.

PiperOrigin-RevId: 168494220

---
Commit 9d56f419c authored by Mingxing Tan<tanmingxing@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add crop_and_decode_jpeg_op that combines the crop and decode for better
performance.

PiperOrigin-RevId: 168493125

---
Commit 48ddf64d0 authored by Chris Leary<leary@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Make large params test only run in opt builds.

PiperOrigin-RevId: 168491913

---
Commit 11d3ac29d authored by Chris Leary<leary@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Add tests for large numbers of parameter / return values and while loops.

PiperOrigin-RevId: 168487225

---
Commit 3cd6bdef5 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Added test cases on R4 slice.

PiperOrigin-RevId: 168482049

---
Commit 46a81b5c3 authored by Jacques Pienaar<jpienaar@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add cast S64 to F32 test.

PiperOrigin-RevId: 168473650

---
Commit 59bdf598d authored by Derek Murray<mrry@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add an automatically-generated "tensorflow.python.platform.build_info" script.

The motivation for this script is to provide better tools for
diagnosing load-time errors (such as the ones that plague the Windows
build due to DLL issues). Note that the script is intended to be
self-contained, so that it is possible to import it without loading
the entire TensorFlow runtime.

This generated script currently contains a single symbol,
`is_cuda_build`, which records whether the build has GPU support or not.

PiperOrigin-RevId: 168471034

---
Commit c3b86347f authored by Olivia Nordquist<nolivia@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
reenabling tests that are passing

PiperOrigin-RevId: 168466361

---
Commit c728665ec authored by Henry Tan<henrytan@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add const qualifiers whenever appropriate.

PiperOrigin-RevId: 168465926

---
Commit bf96fcd13 authored by Alexandre Passos<apassos@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Use the scalar cache in MeanGrad.

PiperOrigin-RevId: 168462267

---
Commit 1cada9ea2 authored by Olivia Nordquist<nolivia@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
reenabling test that passed after 100 runs w/o timing out

PiperOrigin-RevId: 168458634

---
Commit 00c865566 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Generate error (instead of segfault) when trying to copy string tensor
to GPU in EagerTensor constructor.

PiperOrigin-RevId: 168457320

---
Commit 655f26fc7 authored by Alexandre Passos<apassos@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Resurrects autograd-free eager gradients.

PiperOrigin-RevId: 168448557

---
Commit 8f37f3002 authored by Peter Hawkins<phawkins@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[TF:XLA] Cleanups to handling of arguments during XLA compilation:
* combine resource kinds in XlaCompiler::Argument::Kind, use a separate XlaResource::Kind field to distinguish different kinds of resource.
* merge XlaContext::HandleOrConstant and XlaExpression, which were almost identical.
* remove XlaContext::Argument; instead, build XlaExpressions directly from XlaCompiler and add them to the XlaContext.

PiperOrigin-RevId: 168439341

---
Commit 7f5346a80 authored by Gunhan Gulsoy<gunan@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Reduce cmake log mess.

* Echo off for the .bat scripts.
* TF cmake: disable warnings in some of the patched projects (gif,jpeg,lmdb).

PiperOrigin-RevId: 168432070

---
Commit 2ad85aa4d authored by Mark Heffernan<meheff@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Use xla/tests:xla_internal_test_main for all tests under tf/compiler/xla
and remove any main() definitions in tests. This enables use of flags
in all tests.

PiperOrigin-RevId: 168424796

---
Commit cd377811d authored by Henry Tan<henrytan@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Comment and error message consistency cleanup.

PiperOrigin-RevId: 168422582

---
Commit 7c19b82af authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Update tf.sparse_reset_shape so that when shrinking the shape of an empty
sparse tensor, the result has a shape of all zeros.

PiperOrigin-RevId: 168419639

---
Commit fcacb40d4 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
FirstReadyManager for scheduling nodes in VirtualScheduler.
The current FIFOManager may yield inefficient scheduling; _Recv pushed to the
FIFO blocks other nodes that can run before _Recv due to the node order in FIFO.
FirstReadyManager picks a node with the earliest time_ready in the queue,
avoiding this problem.

Also, fixed VirtualPlacer to properly set device when Node's device name does not
include job name and to set GPU:0 as default device.

PiperOrigin-RevId: 168418455

---
Commit 7e47624f5 authored by Asim Shankar<ashankar@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
eager: Initial support for iteration over tf.contrib.data.Dataset objects.

TODO:
- Support function-valued operation attributes in eager
  (Required for MapDataset, FilterDataset etc. which encode the
  per-element computation in a TensorFlow function)
PiperOrigin-RevId: 168418250

---
Commit b0a397fce authored by Asim Shankar<ashankar@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
eager: Remove unnecessary TFE_Context argument to TFE_OpSetDevice.

PiperOrigin-RevId: 168417999

---
Commit 86211d554 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Graph transform to flatten atrous (dilated) convolutions (i.e., a sequence of SpaceToBatchND-Conv-BatchToSpaceND ops) to a regular Conv op with upsampled filters.

PiperOrigin-RevId: 168414124

---
Commit 3438981ca authored by David G. Andersen<dga@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Apply exported symbol filtering to the c++ API analogously to
what is filtered for the C API.
Fixes bug reported in comments on #1924

PiperOrigin-RevId: 168413719

---
Commit 7e023d865 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA:CPU] Remove code from parallel CPU backend outlining that was causing unnecessary copies to be inserted, and which is no longer necessary since we added co-located buffer support for kCall.
*) All bitcast copy is no longer necessary as CopyInsertion will insert copies
at the root of the computation for a parameter which is live-out.
*) Copy if root does not define buffer no longer necessary because colocated
assignment looks at points-to set of root instruction.

PiperOrigin-RevId: 168412076

---
Commit 5da4df92c authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Simplify some code in grappler_item_builder.cc, no change in logic.

PiperOrigin-RevId: 168409110

---
Commit 82ec6241a authored by drpngx<drpngx@users.noreply.github.com>
Committed by GitHub<noreply@github.com>:
Add six and numpy imports
---
Commit 9c4ce2452 authored by Mark Heffernan<meheff@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add flag parsing to more tests in xla/service specifically those which build
HLO graphs. This enables, for example, dumping of the graphs with
--xla_generate_hlo_graph. Also remove some superfluous tensorflow test_main
dependencies.

PiperOrigin-RevId: 168406746

---
Commit d4efa695c authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Relax the feed_nodes collection check, which triggers a false positive in some modes where the feed node collection is auto-generated. Keep it as a warning to help correct user-provided feed node lists.

PiperOrigin-RevId: 168396408

---
Commit cbc46a856 authored by Changming Sun<chasun@microsoft.com>
Committed by gunan<gunan@google.com>:
Add a missing template explicit instantiation of SetZeroFunctor (#12791)

---
Commit 7bb08f5bf authored by Kevin Slagle<kjslag@gmail.com>
Committed by drpngx<drpngx@users.noreply.github.com>:
fix ExponentialMovingAverage documentation so that ExponentialMovingAverage.apply is evaluated within control_dependencies (#12987)

---
Commit e6b011763 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Extend c++ gradient_checker to complex types.

PiperOrigin-RevId: 168392949

---
Commit 4086219a4 authored by Lyndon White<oxinabox@ucc.asn.au>
Committed by drpngx<drpngx@users.noreply.github.com>:
Correct minor typo in substr docs example (#12991)

---
Commit f63aa7f49 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Migrate core TFGAN functions to opensource.

PiperOrigin-RevId: 168391923

---
Commit bc6b60f1b authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fix tuple_losses bug caused by Python bug.

PiperOrigin-RevId: 168386341

---
Commit 7a8c63da3 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Migrate `leaky_relu` to `nn_ops.py`. Will be used for TFGAN.

PiperOrigin-RevId: 168386268

---
Commit f7ba16fdf authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Do not export from eval on train data steps.

PiperOrigin-RevId: 168374021

---
Commit 9b9e54b34 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Adding NCCL sum op, register all_sum gradient.
Streamlining nccl test.

PiperOrigin-RevId: 168347428

---
Commit bc300318e authored by Gunhan Gulsoy<gunan@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Update gemmlowp hash as the commit history seems to have changed in the
repository.

PiperOrigin-RevId: 168343607

---
Commit 1e96d54d9 authored by gunan<gunan@google.com>
Committed by GitHub<noreply@github.com>:
Also accept non-k8 CPU types in build pip package. (#12975)

* Also accept non-k8 CPU types in build pip package.
Fixes #12735

* Make the script work with `set -e`.

---
Commit c0a4c7ffc authored by Chris Leary<leary@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Fix bug in ShapeUtil::ShapeIs that would lead to type inference errors.

PiperOrigin-RevId: 168323589

---
Commit 4af9be964 authored by Amy<amy@infosleuth.net>
Committed by drpngx<drpngx@users.noreply.github.com>:
support passing in a source url to the mnist read_data_sets function, to make it easier to use 'fashion mnist' etc. (#12983)

---
Commit 9f848734f authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Tweak layer a bit to be eager friendly.

PiperOrigin-RevId: 168312865

---
Commit 60f15462b authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Change conv_input_scale and side_input_scale from attributes to inputs for improved flexibility, in fused_conv2d_bias_activation op.

PiperOrigin-RevId: 168311988

---
Commit 4b4e10f9c authored by Jianwei Xie<xiejw@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Adds dict support of eval metrics.

PiperOrigin-RevId: 168310444

---
Commit ab7f22de6 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Move FusedConvBiasActivationShape out of common_shape_fns.cc to a lambda inside the op.

PiperOrigin-RevId: 168300911

---
Commit 3a98035fa authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Augment metadata output with source-line info, as before.

PiperOrigin-RevId: 168292527

---
Commit 349188152 authored by Yao Zhang<yaozhang@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Enable fused batch norm, which is 15-20% faster for training and inference.

PiperOrigin-RevId: 168288154

---
Commit 08587d45b authored by Yuefeng Zhou<yuefengz@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Added back persistent memory tracking in queue op. The new tracking logic has avoided the crash in previous implementation:  the queue_ passed to CreateTypedQueue may be unreffed if the resource is already created by another resource op that shares the same resource name and type.

PiperOrigin-RevId: 168284509

---
Commit 733063d55 authored by Amit Patankar<amitpatankar@google.com>
Committed by Amit Patankar<amitpatankar@google.com>:
Fixing awkward wording.

---
Commit c7ad6bfef authored by Amit Patankar<amitpatankar@google.com>
Committed by Amit Patankar<amitpatankar@google.com>:
Removing accidental hash.

---
Commit 53dbc761a authored by Amit Patankar<amitpatankar@google.com>
Committed by Amit Patankar<amitpatankar@google.com>:
Adding Windows self check script to docs.

---
Commit ed1135994 authored by Andrew Harp<andrewharp@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add -latomic flag to benchmark_model target to fix Android x86 build.

PiperOrigin-RevId: 168281337

---
Commit c0348bb55 authored by Anna R<annarev@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Update tf_export.py to take constant name as an argument instead of a constant.

PiperOrigin-RevId: 168280613

---
Commit c3d19e40a authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Cleanup training_ops to reduce code redudancy.

PiperOrigin-RevId: 168280069

---
Commit 123fb01ee authored by Yao Zhang<yaozhang@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Set fused=False for batch norm, because the test assumes no bessel's
correction. Fused=True would add bessel's correction to variance.

PiperOrigin-RevId: 168274392

---
Commit f0e8c545e authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Switch resource variables from copy-on-read to copy-on-write.

RELNOTES: Change the signature of (C++) GetInputTensorFromVariable in
training_op_helpers to support new copy-on-write semenatics of resource
variables.
PiperOrigin-RevId: 168273249

---
Commit 495cc8e47 authored by Yuan (Terry) Tang<terrytangyuan@users.noreply.github.com>
Committed by drpngx<drpngx@users.noreply.github.com>:
Minor wording change in timeseries module's README (#12938)

* Minor wording change in timeseries module's README

* Address comments

---
Commit f13b876ed authored by Amit Patankar<amitpatankar@google.com>
Committed by Amit Patankar<amitpatankar@google.com>:
Making the default build from source version 1.4.0dev. The whl files that are built will be 1.3.0devDDMMYYYY.

---
Commit 2356c0ff4 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Delete ScopedTFStatus to avoid leaking it for long running trainers(1+day).

PiperOrigin-RevId: 168259652

---
Commit e15f4cae2 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Don't remove all aliases from linalg namespace.
Get rid of redundant aliases.

PiperOrigin-RevId: 168257658

---
Commit c58082642 authored by postBG<profile2697@gmail.com>
Committed by drpngx<drpngx@users.noreply.github.com>:
Fix minor typo in Programmers guide (#12965)

* Fix minor typo in Programmers guide

* change to "this"

---
Commit 509372c2e authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add a lot of operations' flops calculations

PiperOrigin-RevId: 168256746

---
Commit 80ed8afc0 authored by Francois Chollet<fchollet@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add Flatten to core layers.

PiperOrigin-RevId: 168254118

---
Commit a6223c01a authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fix locking of variables in SparseProximalGradientDescent,
AdagradDA, SparseAdagradDA.

PiperOrigin-RevId: 168252530

---
Commit abde00830 authored by Olivia Nordquist<nolivia@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
adding InputTensor class for symmetry with OutputTensor

PiperOrigin-RevId: 168250085

---
Commit 0451032ca authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Fix variable naming style guide violation.

PiperOrigin-RevId: 168245542

---
Commit a202a5a94 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Update ops-related pbtxt files.

PiperOrigin-RevId: 168245371

---
Commit f93e354cb authored by Derek Murray<mrry@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[tf.contrib.data] Switch backend Dataset representation to DT_VARIANT.

This change introduces a new `DatasetWrapper` type that wraps a
`DatasetBase*` and can be stored in a DT_VARIANT tensor. All Dataset
ops now consume and produce DT_VARIANT instead of DT_RESOURCE, and the
underlying implementation is simplified because the `DatasetWrapper`
can be passed directly by value without using the `ResourceMgr`.

PiperOrigin-RevId: 168240571

---
Commit a4042cd2a authored by Jianwei Xie<xiejw@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Introduces the placeholder for _TrainingExecutor, which serves the implementation of tf.estimator.train_and_evaluate.

PiperOrigin-RevId: 168240151

---
Commit 10ba148f7 authored by Peter Hawkins<phawkins@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Switch control_flow_ops library to use Resource variants of Stack operators, instead of deprecated Ref variants.

PiperOrigin-RevId: 168234822

---
Commit ca43fe82b authored by Ali Yahya<alive@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
TFE: Improves the interfaces of tape.watch_variable() and implicit_grad().

tape.watch_variable() replaces tape.watch() and now is called on ResourceVariable objects instead of their underlying handles.

implicit_grad() now returns a list of (gradient, variable) pairs to be consistent with tf.Optimizer's interface.

PiperOrigin-RevId: 168232055

---
Commit b72862dfc authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
internal change

PiperOrigin-RevId: 168225993

---
Commit da3280f4d authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Re-enable tsan for sdca_estimator_test.

PiperOrigin-RevId: 168186374

---
Commit c936c1155 authored by Yifei Feng<yifeif@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fix pip tests for contrib/gan.
- Add *_impl.py so tests can still access removed symbols.
- Add /python directory layer to make *_impy.py and __init__.py not in the same dir.

PiperOrigin-RevId: 168161722

---
Commit ce9a2b00f authored by Toby Boyd<tobyboyd@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Performance guide update

PiperOrigin-RevId: 168159289

---
Commit 3bce4f9a0 authored by Shanqing Cai<cais@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
TFE: expose tfe.num_gpus()

PiperOrigin-RevId: 168154345

---
Commit 67a7cbc28 authored by Jianwei Xie<xiejw@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Changed the default eval throttle secs from 2 min to 10 mins.

PiperOrigin-RevId: 168120323

---
Commit 92bed178f authored by Eugene Brevdo<ebrevdo@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Reduce cmake log mess.

* Echo off for the .bat scripts.
* TF cmake: disable warnings in some of the patched projects (gif,jpeg,lmdb).

PiperOrigin-RevId: 168119914

---
Commit 702d59582 authored by joshkyh<joshkyh@users.noreply.github.com>
Committed by Yifei Feng<fengyifei2026@gmail.com>:
Corrected hyperlink for audio training tutorial (#12923)

---
Commit 877c9deca authored by Frank Chen<frankchn@gmail.com>
Committed by Yifei Feng<fengyifei2026@gmail.com>:
Reverse change eb75ded6 so that internal tests will pass. (#12933)

As support for int64 global steps is not ready in TPUs, I am reversing this change so that our internal performance and regression tests will pass.
---
Commit 665966438 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Re-enable grpc_session_test.

PiperOrigin-RevId: 168078694

---
Commit 405def792 authored by Chris Leary<leary@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Switch CallInliner to use CallGraph::VisitNodes.

PiperOrigin-RevId: 168078645

---
Commit aba3466f1 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Exposes Head and factory methods in tf.contrib.estimator.

PiperOrigin-RevId: 168071246

---
Commit b76565b39 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Some profiler fixes and cleanup.

PiperOrigin-RevId: 168069346

---
Commit 32ffc5a81 authored by Jonas<sauercrowd@users.noreply.github.com>
Committed by Yifei Feng<fengyifei2026@gmail.com>:
Just a dot in order to be consistent (#12919)

added a dot to the `7` to make clear it's a float (like every other number)
---
Commit 0753b0c79 authored by Alexandre Passos<apassos@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Scope the scalar cache in the context.

PiperOrigin-RevId: 168065417

---
Commit 48deb206b authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Migrate TFGAN features to third_party.

PiperOrigin-RevId: 168060880

---
Commit d2ae1311f authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fixing an issue in the BUILD file of the LSH ops.

PiperOrigin-RevId: 168056645

---
Commit 2f440eda4 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Expose NumpyReader for reading timeseries data.

PiperOrigin-RevId: 168055838

---
Commit be1916ce7 authored by Daniel Grazian<dgr@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Added functionality to allow `SqlDataset` to interpret a database column as various numeric types, including several integer types and `dtypes.float64`.

PiperOrigin-RevId: 168055827

---
Commit fa2000a0b authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Supporting nightly windows pip packages.

PiperOrigin-RevId: 168054959

---
Commit a263ea626 authored by Asim Shankar<ashankar@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
eager: Treat eager tensors as constants during graph construction.

Unless capturing is explicitly enabled.

PiperOrigin-RevId: 168052675

---
Commit 6e402d0d2 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Make TODO a bit more specific.

PiperOrigin-RevId: 168051381

---
Commit c779384bc authored by Daniel Grazian<dgr@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Added code example to the doc string for `SqlDataset`.

PiperOrigin-RevId: 168049037

---
Commit ff6dd474a authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Use self._in_graph_mode consistently in ResourceVariable
instead of sometimes getting it from the context.

Also: fix formatting of a comment and use a more precise test to detect
if initial_value is set.
PiperOrigin-RevId: 168047258

---
Commit f331f528b authored by Alexandre Passos<apassos@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Removes "fast paths" which are not fast in eager mode.

PiperOrigin-RevId: 168046278

---
Commit 86f1713e5 authored by Jianwei Xie<xiejw@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Introduces TrainSpec and EvalSpec.

PiperOrigin-RevId: 168040435

---
Commit c8b9e92f0 authored by Asim Shankar<ashankar@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
eager: Move "register_function" to context.py

This will allow function registration from other
modules without having to import "function.py".
(And besides, the function really does belong on the context).

PiperOrigin-RevId: 168040411

---
Commit 74137f994 authored by Shanqing Cai<cais@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fix signed int overflow issue in tensor_id.cc

When a node name has a long numeric suffix, e.g.,
"foo/y_0/gradient_debug_09684b60f2184c67b744721915034528" (as has happened with tfdbg GradientsDebugger),

the parsing algorithm in ParseTensorName() may experience signed int overflow. Replacing the types with "unsigned int" resolves the issue.

PiperOrigin-RevId: 168039195

---
Commit 450c3b562 authored by Rohan Jain<rohanj@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Using rendezvous manager to pass args / rets between devices during function remote execution. This enables CPU->GPU remote device executions now.

PiperOrigin-RevId: 168038285

---
Commit 82cc6529f authored by Jianwei Xie<xiejw@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fixes the wording about StopIteration.

PiperOrigin-RevId: 168034451

---
Commit fb5588002 authored by Gunhan Gulsoy<gunan@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add a statement on install/index.md on what os are supported.

PiperOrigin-RevId: 168032996

---
Commit f83f6b9ef authored by Chris Leary<leary@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Handle higher-order HLOs (e.g. While) in CallInliner and test.

PiperOrigin-RevId: 168029345

---
Commit 8988ae365 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
BEGIN_PUBLIC
Automated g4 rollback of changelist 167916124

PiperOrigin-RevId: 168916710

* Update ops-related pbtxt files.

PiperOrigin-RevId: 168917157

* Go: Update generated wrapper functions for TensorFlow ops.

PiperOrigin-RevId: 168917534
This commit is contained in:
drpngx 2017-09-15 19:38:25 -07:00 committed by GitHub
parent 7a97dfc3ec
commit e55574f282
141 changed files with 6463 additions and 1871 deletions

View File

@ -45,11 +45,11 @@ GPU packages on all platforms will arrive soon!
**Individual whl files**
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.4.0dev-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.4.0dev-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.4.0dev-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.4.0dev-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))

View File

@ -283,7 +283,6 @@ filegroup(
"//tensorflow/contrib/crf:all_files",
"//tensorflow/contrib/cudnn_rnn:all_files",
"//tensorflow/contrib/data:all_files",
"//tensorflow/contrib/data/python/framework:all_files",
"//tensorflow/contrib/data/python/kernel_tests:all_files",
"//tensorflow/contrib/data/python/ops:all_files",
"//tensorflow/contrib/data/python/util:all_files",
@ -418,6 +417,7 @@ filegroup(
"//tensorflow/python/profiler/internal:all_files",
"//tensorflow/python/saved_model:all_files",
"//tensorflow/python/tools:all_files",
"//tensorflow/tools/api/generator:all_files",
"//tensorflow/tools/api/golden:all_files",
"//tensorflow/tools/api/lib:all_files",
"//tensorflow/tools/api/tests:all_files",

View File

@ -127,10 +127,11 @@ Status Conv2DGrad(const Scope& scope, const Operation& op,
std::vector<int32> strides;
bool use_cudnn_on_gpu;
auto attrs = op.output(0).node()->attrs();
GetNodeAttr(attrs, "data_format", &data_format);
GetNodeAttr(attrs, "padding", &padding);
GetNodeAttr(attrs, "strides", &strides);
GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu);
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu",
&use_cudnn_on_gpu));
Conv2DBackpropInput::Attrs input_attrs;
input_attrs.DataFormat(data_format);
input_attrs.UseCudnnOnGpu(use_cudnn_on_gpu);
@ -157,10 +158,10 @@ Status MaxPoolGradHelper(const Scope& scope, const Operation& op,
std::vector<int32> strides;
std::vector<int32> ksize;
auto attrs = op.output(0).node()->attrs();
GetNodeAttr(attrs, "data_format", &data_format);
GetNodeAttr(attrs, "ksize", &ksize);
GetNodeAttr(attrs, "padding", &padding);
GetNodeAttr(attrs, "strides", &strides);
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
internal::MaxPoolGrad::Attrs grad_attrs;
grad_attrs.DataFormat(data_format);
auto dx = internal::MaxPoolGrad(scope, op.input(0),
@ -179,8 +180,8 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op,
string data_format;
string padding;
auto attrs = op.output(0).node()->attrs();
GetNodeAttr(attrs, "data_format", &data_format);
GetNodeAttr(attrs, "padding", &padding);
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
MaxPoolGradV2::Attrs grad_attrs;
grad_attrs.DataFormat(data_format);
auto dx = MaxPoolGradV2(scope, op.input(0),

View File

@ -4,7 +4,6 @@ package(
default_visibility = ["//visibility:private"],
)
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
# Optional runtime utilities for use by code generated by tfcompile.
@ -39,32 +38,24 @@ cc_library(
deps = ["//tensorflow/core:test_main"],
)
xla_proto_library(
name = "tfcompile_proto",
srcs = ["tfcompile.proto"],
deps = [
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tfcompile_lib",
srcs = [
"codegen.cc",
"compile.cc",
"flags.cc",
"tfcompile_util.cc",
],
hdrs = [
"codegen.h",
"compile.h",
"flags.h",
"tfcompile_util.h",
],
deps = [
":runtime", # needed by codegen to print aligned_buffer_bytes
":tfcompile_proto",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -82,7 +73,6 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
],
)
@ -99,18 +89,6 @@ cc_test(
],
)
cc_test(
name = "tfcompile_util_test",
srcs = ["tfcompile_util_test.cc"],
deps = [
":tfcompile_lib",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_binary(
name = "tfcompile",
visibility = ["//visibility:public"],
@ -123,7 +101,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":tfcompile_lib",
":tfcompile_proto",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu",
@ -226,7 +205,11 @@ test_suite(
tags = ["manual"],
tests = [
":benchmark_test",
":codegen_test",
":runtime_test",
":test_graph_tfadd_test",
":test_graph_tfunknownop2_test",
":test_graph_tfunknownop3_test",
":test_graph_tfunknownop_test",
"//tensorflow/compiler/aot/tests:all_tests",
],

View File

@ -20,8 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/aot/runtime.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -35,6 +35,12 @@ namespace tfcompile {
namespace {
bool IsAlpha(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}
bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
// Convert an XLA type into a C++ type.
Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
switch (type) {
@ -156,7 +162,7 @@ string RewriteWithName(const string& name, string code,
}
// Generate methods for args (inputs).
Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
const CompileResult& compile_result, string* methods) {
*methods += R"(
void** args() { return args_; }
@ -204,8 +210,8 @@ Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
}
// Generate methods for results (outputs).
Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
string* methods) {
Status GenResultMethods(const tf2xla::Config& config,
const xla::ProgramShape& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
// Non-tuple (i.e. single-result) case.
if (config.fetch_size() != 1) {
@ -285,11 +291,26 @@ Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
return Status::OK();
}
Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
for (const tf2xla::Feed& feed : config.feed()) {
if (!feed.name().empty()) {
TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name"));
}
}
for (const tf2xla::Fetch& fetch : config.fetch()) {
if (!fetch.name().empty()) {
TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name"));
}
}
return Status::OK();
}
} // namespace
Status GenerateHeader(const HeaderOpts& opts, const Config& config,
Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
const CompileResult& compile_result, string* header) {
TF_RETURN_IF_ERROR(ValidateConfig(config));
TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index();
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
if (result_index < 0 || result_index > temp_sizes.size()) {
@ -574,5 +595,29 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
return Status::OK();
}
Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
if (ident.empty()) {
return errors::InvalidArgument("empty identifier: ", msg);
}
// Require that the identifier starts with a nondigit, and is composed of
// nondigits and digits, as specified in section [2.11 Identifiers] of the
// C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
// defined as [0-9].
//
// Technically the standard also allows for `universal-character-name`, with a
// table of allowed unicode ranges, as well as `other implementation-defined
// characters`. We disallow those here to give better error messages, at the
// expensive of being more restrictive than the standard.
if (ident[0] != '_' && !IsAlpha(ident[0])) {
return errors::InvalidArgument("illegal leading char: ", msg);
}
for (size_t pos = 1; pos < ident.size(); ++pos) {
if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
return errors::InvalidArgument("illegal char: ", msg);
}
}
return Status::OK();
}
} // namespace tfcompile
} // namespace tensorflow

View File

@ -20,6 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
namespace tfcompile {
@ -37,7 +39,7 @@ struct HeaderOpts {
// GenerateHeader uses the meta-information from compile_result to generate a
// C++ header giving access to the function in the generated object file. The
// header includes API usage documentation.
Status GenerateHeader(const HeaderOpts& opts, const Config& config,
Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
const CompileResult& compile_result, string* header);
// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces`
@ -47,6 +49,10 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config,
Status ParseCppClass(const string& cpp_class, string* class_name,
std::vector<string>* namespaces);
// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
// appended to error messages.
Status ValidateCppIdent(StringPiece ident, StringPiece msg);
} // namespace tfcompile
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@ -29,6 +30,41 @@ namespace tensorflow {
namespace tfcompile {
namespace {
void ExpectErrorContains(const Status& status, StringPiece str) {
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
TEST(ValidateCppIdent, Simple) {
TF_EXPECT_OK(ValidateCppIdent("a", ""));
TF_EXPECT_OK(ValidateCppIdent("abc", ""));
TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
// Make sure we didn't skip a valid letter or digit
string ident;
for (char c = 'a'; c <= 'z'; c++) {
ident.append(1, c);
}
for (char c = 'A'; c <= 'Z'; c++) {
ident.append(1, c);
}
for (char c = '0'; c <= '9'; c++) {
ident.append(1, c);
}
ident += "_";
TF_EXPECT_OK(ValidateCppIdent(ident, ""));
ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
}
class ParseCppClassTest : public ::testing::Test {
protected:
void ExpectOK(const string& cpp_class, const string& want_class_name,
@ -91,13 +127,13 @@ TEST(GenerateHeader, Golden) {
HeaderOpts opts;
opts.class_name = "MyClass";
opts.namespaces = {"foo", "bar"};
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("feed0");
feed->set_name("myfeed");
feed = config.add_feed();
feed->mutable_id()->set_node_name("feed1");
Fetch* fetch = config.add_fetch();
tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch");
CompileResult compile_result;

View File

@ -15,326 +15,32 @@ limitations under the License.
#include "tensorflow/compiler/aot/compile.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace tfcompile {
const char* const kArgOp = "_Arg";
const char* const kRetvalOp = "_Retval";
const char* const kFeedIdAttr = "_feed_id";
const char* const kFetchIdAttr = "_fetch_id";
const char* const kShapeAttr = "_shape";
const char* const kDebugNameAttr = "_debug_name";
namespace {
Status DumpGraph(const MainFlags& flags, const string& name,
const Graph& graph) {
if (flags.debug_dir.empty()) {
return Status::OK();
}
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
string file = io::JoinPath(flags.debug_dir, name + ".pbtxt");
return WriteTextProto(Env::Default(), file, graph_def);
}
typedef std::unordered_map<string, Node*> NodeMap;
// Each feed id identifies the positional output of some node, which may consist
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
// tensor with a placeholder. For each feed tensor, replaces all edges so they
// point from a new _Arg node instead.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
const protobuf::RepeatedPtrField<Feed>& feeds,
const std::unordered_map<string, string>& feed_remapping) {
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
const Feed& feed = feeds[arg_index];
// All feeds have been replaced by placeholders.
const int output_index = 0;
const string key = TensorIdToString(feed.id());
const auto remap_it = feed_remapping.find(key);
auto node_it = node_map.find(remap_it->second);
if (node_it == node_map.end()) {
// Strip off the aot_feed_#/ prefix.
StringPiece name(remap_it->second);
const auto index = name.find('/');
if (index > 0) name.remove_prefix(index + 1);
return errors::InvalidArgument(
"Node is fed but not needed for fetching: ", name);
}
const Node* feed_node = node_it->second;
// TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
// "_shape" attr if we can determine it. That way the graph will be
// initialized with whatever shapes we can infer, while the user can still
// explicitly specify or override them.
Node* arg_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
.Attr("T", BaseType(feed_node->output_type(output_index)))
.Attr("index", arg_index)
.Attr(kFeedIdAttr, TensorIdToString(feed.id()))
.Attr(kShapeAttr, TensorShape(feed.shape()))
.Attr(kDebugNameAttr, feed.name())
.Finalize(graph, &arg_node));
// Collects out-edges from the feed node that have a matching edge index;
// these will be replaced with edges from the arg node instead.
//
// We must collect the edges first and process them in a second pass, since
// removing the edge from the graph invalidates feed_node->out_edges.
std::vector<const Edge*> feed_edges;
for (const Edge* edge : feed_node->out_edges()) {
if (edge->src_output() == output_index) {
feed_edges.push_back(edge);
}
}
for (const Edge* edge : feed_edges) {
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
}
}
return Status::OK();
}
// Each fetch id identifies the positional output of some node. For each fetch
// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
const protobuf::RepeatedPtrField<Fetch>& fetches,
std::unordered_set<const Node*>* retval_nodes) {
for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
const TensorId& id = fetches[ret_index].id();
auto it = node_map.find(id.node_name());
if (it == node_map.end()) {
return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
}
Node* fetch_node = it->second;
if (id.output_index() >= fetch_node->num_outputs()) {
return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
", output index should be < ",
fetch_node->num_outputs());
}
// Connects fetch_node -> retval_node.
Node* retval_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
.Input(fetch_node, id.output_index())
.Attr("T", BaseType(fetch_node->output_type(id.output_index())))
.Attr("index", ret_index)
.Attr(kFetchIdAttr, TensorIdToString(id))
.Finalize(graph, &retval_node));
retval_nodes->insert(retval_node);
}
return Status::OK();
}
// RewriteAndPruneGraph identifies input and output edges (named by the feed and
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
// execution to know the input and output args for the generated function.
Status RewriteAndPruneGraph(
Graph* graph, const Config& config,
const std::unordered_map<string, string>& feed_remapping,
const MainFlags& flags) {
NodeMap node_map;
for (Node* n : graph->nodes()) {
node_map[n->name()] = n;
}
TF_RETURN_IF_ERROR(
AddArgNodes(graph, node_map, config.feed(), feed_remapping));
std::unordered_set<const Node*> retval_nodes;
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph));
PruneForReverseReachability(graph, retval_nodes);
FixupSourceAndSinkEdges(graph);
TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph));
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
std::set<string> missing_feeds, missing_fetches;
for (const Feed& feed : config.feed()) {
missing_feeds.insert(TensorIdToString(feed.id()));
}
for (const Fetch& fetch : config.fetch()) {
missing_fetches.insert(TensorIdToString(fetch.id()));
}
for (const Node* n : graph->op_nodes()) {
if (n->type_string() == kArgOp) {
string feed_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
if (missing_feeds.erase(feed_id) == 0) {
return errors::Aborted(kArgOp,
" node found with unknown feed id: ", feed_id);
}
} else if (n->type_string() == kRetvalOp) {
string fetch_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
if (missing_fetches.erase(fetch_id) == 0) {
return errors::Aborted(kRetvalOp,
" node found with unknown fetch id: ", fetch_id);
}
}
}
if (!missing_feeds.empty() || !missing_fetches.empty()) {
return errors::Aborted(
"Post graph-pruning",
", missing feeds: ", str_util::Join(missing_feeds, ", "),
", missing fetches: ", str_util::Join(missing_fetches, ", "));
}
return Status::OK();
}
// CollectArgNodes collects _Arg nodes from the graph, and performs basic
// sanity-checking to ensure the index and type attributes of each node are
// initialized correctly.
Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
std::map<int, Node*> indexed_arg_nodes;
for (Node* n : graph.nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
auto insert_result = indexed_arg_nodes.insert({index, n});
if (!insert_result.second) {
const Node* dup = insert_result.first->second;
return errors::InvalidArgument(
"Multiple ", kArgOp, " nodes with index ", index, ", ",
n->DebugString(), " and ", dup->DebugString());
}
}
}
arg_nodes->clear();
for (const auto& index_node : indexed_arg_nodes) {
if (index_node.first != arg_nodes->size()) {
return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
arg_nodes->size(), ", but got index ",
index_node.first);
}
arg_nodes->push_back(index_node.second);
}
return Status::OK();
}
// Fills in xla_args from the corresponding _Arg nodes in the graph.
Status CreateXlaArgs(const Graph& graph,
std::vector<XlaCompiler::Argument>* xla_args) {
std::vector<Node*> arg_nodes;
TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
for (const Node* node : arg_nodes) {
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kParameter;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
TensorShape shape;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
xla_args->push_back(arg);
}
return Status::OK();
}
// Converts the TensorFlow graph into an XLA computation, by executing the
// graph symbolically, with each op building up the XLA HLO.
Status ConvertGraphToXla(xla::CompileOnlyClient* client,
std::unique_ptr<Graph> graph,
xla::Computation* computation, bool* has_context_arg) {
// Create a device and context to convert the graph into an XLA computation.
XlaOpRegistry::RegisterCompilationKernels();
// Populate the context with args from the graph.
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
}
std::vector<XlaCompiler::Argument> xla_args;
TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
// Compile the graph into an XLA computation.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
DeviceType device_type(DEVICE_CPU_XLA_JIT);
compiler_options.device_type = &device_type;
compiler_options.flib_def = &graph->flib_def();
compiler_options.graph_def_version = graph->versions().producer();
compiler_options.allow_cpu_custom_calls = true;
XlaCompiler compiler(compiler_options);
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
"tfcompile", std::move(graph),
xla_args, &result));
*has_context_arg = result.requires_runtime_context;
*computation = std::move(*result.computation);
int num_const_results = 0;
for (int i = 0; i < result.outputs.size(); ++i) {
// Ending up with const results (i.e. output args) is an error, since it
// means that one or more fetches that the user specified will be dropped
// from the generated function. It's most likely a configuration error,
// since the user shouldn't be asking for output args that end up as consts.
//
// TODO(toddw): Provide a way for the user to access const output args,
// e.g. perhaps hard-coded into the header, or somehow copied into the
// output buffers.
if (result.outputs[i].is_constant) {
++num_const_results;
LOG(ERROR) << "ConstRetVal index:" << i
<< " value:" << result.outputs[i].constant_value.DebugString();
}
}
if (num_const_results > 0) {
return errors::Unimplemented(
"Conversion from TensorFlow graph to XLA resulted in ",
num_const_results,
" constant results. The configuration of "
"the output args (i.e. fetch ids) is probably wrong.");
}
if (computation->IsNull()) {
return errors::Aborted(
"Conversion from TensorFlow graph to XLA resulted in an empty "
"computation.");
}
return Status::OK();
}
// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
const xla::Computation& computation,
@ -376,41 +82,8 @@ Status CompileXla(xla::CompileOnlyClient* client,
} // namespace
Status InitGraph(const GraphDef& graph_def, const Config& config,
const MainFlags& flags, std::unique_ptr<Graph>* graph) {
TF_RETURN_IF_ERROR(ValidateConfig(config));
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
std::unique_ptr<Graph> g(new Graph(flib_def));
// Replace references to fed tensors with references to newly added
// placeholders.
GraphDef first_copy_def = graph_def;
// Maps from name:port of a feed to the name:port of the placeholder to use.
std::unordered_map<string, string> feed_remapping;
TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
&feed_remapping, &first_copy_def));
// Prune the GraphDef first so that unknown ops that we aren't compiling get
// filtered out.
GraphDef second_copy_def;
TF_RETURN_IF_ERROR(
PruneGraphDefInto(config, first_copy_def, &second_copy_def));
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
&second_copy_def, *g->op_registry(), 0 /*node_offset*/));
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
second_copy_def, g.get()));
TF_RETURN_IF_ERROR(
RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
*graph = std::move(g);
return Status::OK();
}
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
CompileResult* compile_result) {
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the
// computation.
// TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
@ -421,8 +94,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
xla::Computation computation;
TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
&compile_result->has_context_arg));
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client,
&computation,
&compile_result->has_context_arg));
if (!flags.debug_dir.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
computation.Snapshot());

View File

@ -18,46 +18,16 @@ limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
namespace tfcompile {
// Constants for op types and attribute names.
extern const char* const kArgOp;
extern const char* const kRetvalOp;
extern const char* const kFeedIdAttr;
extern const char* const kFetchIdAttr;
extern const char* const kShapeAttr;
extern const char* const kDebugNameAttr;
// InitGraph creates a graph based on the graph_def, that may then be compiled
// by CompileGraph.
//
// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
// and outputs of the function that will be compiled. Each feed id causes a new
// _Arg node to be created, where we first collect all existing edges pointing
// from the named node's output index, and then rewrite them to point from that
// _Arg node instead. Each fetch id causes a new _Retval node to be created,
// with a new edge pointing from the named node's output index to that _Retval
// node. All _Retval nodes also point to a special CompileExpressions node,
// used internally to finish the compilation.
//
// The rewritten graph is then pruned to only contain the portion necessary to
// compute the outputs. If dump_graphs is true, graph rewrites will be dumped
// for debugging.
Status InitGraph(const GraphDef& graph_def, const Config& config,
const MainFlags& flags, std::unique_ptr<Graph>* graph);
// CompileResult describes the output of CompileGraph, where the object file
// data and meta-information is available in aot.
struct CompileResult {
@ -69,20 +39,12 @@ struct CompileResult {
int pointer_size = 0; // Size of a pointer in bytes.
};
// CompileGraph compiles the graph into an object file containing a function
// CompileGraph compiles the graph_def into an object file containing a function
// that performs the graph operations.
//
// The graph must have _Arg and _Retval nodes representing the function inputs
// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr,
// value=TensorShape) representing the static shape of that input, and every
// _Retval node must point to a CompileExpressions node.
//
// Typically InitGraph is called to perform this initialization, followed by
// full specification of the shape attributes.
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
CompileResult* result);
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result);
} // namespace tfcompile
} // namespace tensorflow

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {

View File

@ -13,9 +13,11 @@ test_suite(
":test_graph_tfadd_test",
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
":tfcompile_test",
],
)
@ -90,6 +92,15 @@ tf_library(
tags = ["manual"],
)
tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
tags = ["manual"],
)
tf_library(
name = "test_graph_tfgather",
testonly = 1,
@ -117,15 +128,6 @@ tf_library(
tags = ["manual"],
)
tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
tags = ["manual"],
)
tf_library(
name = "test_graph_tfsplits",
testonly = 1,

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "params" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x" }
shape {

View File

@ -41,7 +41,7 @@ def tf_library(name, graph, config,
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
is expected to be in the human-readable proto text format, otherwise it is
expected to be in the proto binary format.
config: File containing tensorflow.tfcompile.Config proto. If the file ends
config: File containing tensorflow.tf2xla.Config proto. If the file ends
in '.pbtxt' it is expected to be in the human-readable proto text format,
otherwise it is expected to be in the proto binary format.
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
@ -54,8 +54,7 @@ const char kUsageHeader[] =
"--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
Status ReadProtoFile(const string& kind, const string& fname,
protobuf::Message* proto) {
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (StringPiece(fname).ends_with(".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
@ -63,23 +62,17 @@ Status ReadProtoFile(const string& kind, const string& fname,
}
}
void ParseTensorId(const string& name, TensorId* id) {
const std::pair<StringPiece, int> name_index = ParseTensorName(name);
id->set_node_name(name_index.first.ToString());
id->set_output_index(name_index.second);
}
Status Main(const MainFlags& flags) {
// Process config.
Config config;
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const Fetch& fetch : config.fetch()) {
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << str_util::Join(nodes, ",");
@ -91,12 +84,9 @@ Status Main(const MainFlags& flags) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
std::unique_ptr<Graph> graph;
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph));
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &compile_result));
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();

View File

@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} }

View File

@ -21,6 +21,40 @@ package(
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
xla_proto_library(
name = "tf2xla_proto",
srcs = ["tf2xla.proto"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tf2xla",
srcs = ["tf2xla.cc"],
hdrs = ["tf2xla.h"],
visibility = ["//visibility:public"],
deps = [
":common",
":dump_graph",
":tf2xla_proto",
":tf2xla_util",
":xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "xla_compiler",
@ -96,6 +130,51 @@ cc_library(
# Internal targets below this point.
cc_library(
name = "tf2xla_util",
srcs = ["tf2xla_util.cc"],
hdrs = ["tf2xla_util.h"],
deps = [
":tf2xla_proto",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "tf2xla_util_test",
srcs = ["tf2xla_util_test.cc"],
deps = [
":tf2xla_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_test(
name = "tf2xla_test",
srcs = ["tf2xla_test.cc"],
deps = [
":tf2xla",
":tf2xla_proto",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],

View File

@ -0,0 +1,370 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
const char* const kArgOp = "_Arg";
const char* const kRetvalOp = "_Retval";
const char* const kFeedIdAttr = "_feed_id";
const char* const kFetchIdAttr = "_fetch_id";
const char* const kShapeAttr = "_shape";
const char* const kDebugNameAttr = "_debug_name";
namespace {
typedef std::unordered_map<string, Node*> NodeMap;
// Each feed id identifies the positional output of some node, which may consist
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
// tensor with a placeholder. For each feed tensor, replaces all edges so they
// point from a new _Arg node instead.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
const std::unordered_map<string, string>& feed_remapping) {
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
const tf2xla::Feed& feed = feeds[arg_index];
// All feeds have been replaced by placeholders.
const int output_index = 0;
const string key = TensorIdToString(feed.id());
const auto remap_it = feed_remapping.find(key);
auto node_it = node_map.find(remap_it->second);
if (node_it == node_map.end()) {
// Strip off the aot_feed_#/ prefix.
StringPiece name(remap_it->second);
const auto index = name.find('/');
if (index > 0) name.remove_prefix(index + 1);
return errors::InvalidArgument(
"Node is fed but not needed for fetching: ", name);
}
const Node* feed_node = node_it->second;
// TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
// "_shape" attr if we can determine it. That way the graph will be
// initialized with whatever shapes we can infer, while the user can still
// explicitly specify or override them.
Node* arg_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
.Attr("T", BaseType(feed_node->output_type(output_index)))
.Attr("index", arg_index)
.Attr(kFeedIdAttr, TensorIdToString(feed.id()))
.Attr(kShapeAttr, TensorShape(feed.shape()))
.Attr(kDebugNameAttr, feed.name())
.Finalize(graph, &arg_node));
// Collects out-edges from the feed node that have a matching edge index;
// these will be replaced with edges from the arg node instead.
//
// We must collect the edges first and process them in a second pass, since
// removing the edge from the graph invalidates feed_node->out_edges.
std::vector<const Edge*> feed_edges;
for (const Edge* edge : feed_node->out_edges()) {
if (edge->src_output() == output_index) {
feed_edges.push_back(edge);
}
}
for (const Edge* edge : feed_edges) {
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
}
}
return Status::OK();
}
// Each fetch id identifies the positional output of some node. For each fetch
// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches,
std::unordered_set<const Node*>* retval_nodes) {
for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
const tf2xla::TensorId& id = fetches[ret_index].id();
auto it = node_map.find(id.node_name());
if (it == node_map.end()) {
return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
}
Node* fetch_node = it->second;
if (id.output_index() >= fetch_node->num_outputs()) {
return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
", output index should be < ",
fetch_node->num_outputs());
}
// Connects fetch_node -> retval_node.
Node* retval_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
.Input(fetch_node, id.output_index())
.Attr("T", BaseType(fetch_node->output_type(id.output_index())))
.Attr("index", ret_index)
.Attr(kFetchIdAttr, TensorIdToString(id))
.Finalize(graph, &retval_node));
retval_nodes->insert(retval_node);
}
return Status::OK();
}
// RewriteAndPruneGraph identifies input and output edges (named by the feed and
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
// execution to know the input and output args for the generated function.
Status RewriteAndPruneGraph(
Graph* graph, const tf2xla::Config& config,
const std::unordered_map<string, string>& feed_remapping) {
NodeMap node_map;
for (Node* n : graph->nodes()) {
node_map[n->name()] = n;
}
TF_RETURN_IF_ERROR(
AddArgNodes(graph, node_map, config.feed(), feed_remapping));
std::unordered_set<const Node*> retval_nodes;
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
VLOG(2) << "Post rewrite: "
<< dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph);
PruneForReverseReachability(graph, retval_nodes);
FixupSourceAndSinkEdges(graph);
VLOG(2) << "Post prune: "
<< dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph);
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
std::set<string> missing_feeds, missing_fetches;
for (const tf2xla::Feed& feed : config.feed()) {
missing_feeds.insert(TensorIdToString(feed.id()));
}
for (const tf2xla::Fetch& fetch : config.fetch()) {
missing_fetches.insert(TensorIdToString(fetch.id()));
}
for (const Node* n : graph->op_nodes()) {
if (n->type_string() == kArgOp) {
string feed_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
if (missing_feeds.erase(feed_id) == 0) {
return errors::Aborted(kArgOp,
" node found with unknown feed id: ", feed_id);
}
} else if (n->type_string() == kRetvalOp) {
string fetch_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
if (missing_fetches.erase(fetch_id) == 0) {
return errors::Aborted(kRetvalOp,
" node found with unknown fetch id: ", fetch_id);
}
}
}
if (!missing_feeds.empty() || !missing_fetches.empty()) {
return errors::Aborted(
"Post graph-pruning",
", missing feeds: ", str_util::Join(missing_feeds, ", "),
", missing fetches: ", str_util::Join(missing_fetches, ", "));
}
return Status::OK();
}
// CollectArgNodes collects _Arg nodes from the graph, and performs basic
// sanity-checking to ensure the index and type attributes of each node are
// initialized correctly.
Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
std::map<int, Node*> indexed_arg_nodes;
for (Node* n : graph.nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
auto insert_result = indexed_arg_nodes.insert({index, n});
if (!insert_result.second) {
const Node* dup = insert_result.first->second;
return errors::InvalidArgument(
"Multiple ", kArgOp, " nodes with index ", index, ", ",
n->DebugString(), " and ", dup->DebugString());
}
}
}
arg_nodes->clear();
for (const auto& index_node : indexed_arg_nodes) {
if (index_node.first != arg_nodes->size()) {
return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
arg_nodes->size(), ", but got index ",
index_node.first);
}
arg_nodes->push_back(index_node.second);
}
return Status::OK();
}
// Fills in xla_args from the corresponding _Arg nodes in the graph.
Status CreateXlaArgs(const Graph& graph,
std::vector<XlaCompiler::Argument>* xla_args) {
std::vector<Node*> arg_nodes;
TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
for (const Node* node : arg_nodes) {
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kParameter;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
TensorShape shape;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
xla_args->push_back(arg);
}
return Status::OK();
}
// Converts the TensorFlow graph into an XLA computation, by executing the
// graph symbolically, with each op building up the XLA HLO.
Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
xla::Computation* computation,
bool* requires_runtime_context) {
// Create a device and context to convert the graph into an XLA computation.
XlaOpRegistry::RegisterCompilationKernels();
// Populate the context with args from the graph.
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
}
std::vector<XlaCompiler::Argument> xla_args;
TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
// Compile the graph into an XLA computation.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
DeviceType device_type(DEVICE_CPU_XLA_JIT);
compiler_options.device_type = &device_type;
compiler_options.flib_def = &graph->flib_def();
compiler_options.graph_def_version = graph->versions().producer();
compiler_options.allow_cpu_custom_calls = true;
XlaCompiler compiler(compiler_options);
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
"tfcompile", std::move(graph),
xla_args, &result));
*requires_runtime_context = result.requires_runtime_context;
*computation = std::move(*result.computation);
int num_const_results = 0;
for (int i = 0; i < result.outputs.size(); ++i) {
// Ending up with const results (i.e. output args) is an error, since it
// means that one or more fetches that the user specified will be dropped
// from the generated function. It's most likely a configuration error,
// since the user shouldn't be asking for output args that end up as consts.
//
// TODO(toddw): Provide a way for the user to access const output args,
// e.g. perhaps hard-coded into the header, or somehow copied into the
// output buffers.
if (result.outputs[i].is_constant) {
++num_const_results;
LOG(ERROR) << "ConstRetVal index:" << i
<< " value:" << result.outputs[i].constant_value.DebugString();
}
}
if (num_const_results > 0) {
return errors::Unimplemented(
"Conversion from TensorFlow graph to XLA resulted in ",
num_const_results,
" constant results. The configuration of "
"the output args (i.e. fetch ids) is probably wrong.");
}
if (computation->IsNull()) {
return errors::Aborted(
"Conversion from TensorFlow graph to XLA resulted in an empty "
"computation.");
}
return Status::OK();
}
// InitGraph creates a graph based on the graph_def, that may then be converted
// to an xla::Computation via ConvertGraphToXla.
//
// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
// and outputs of the function that will be compiled. Each feed id causes a new
// _Arg node to be created, where we first collect all existing edges pointing
// from the named node's output index, and then rewrite them to point from that
// _Arg node instead. Each fetch id causes a new _Retval node to be created,
// with a new edge pointing from the named node's output index to that _Retval
// node.
Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
std::unique_ptr<Graph>* graph) {
TF_RETURN_IF_ERROR(ValidateConfig(config));
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
std::unique_ptr<Graph> g(new Graph(flib_def));
// Replace references to fed tensors with references to newly added
// placeholders.
GraphDef first_copy_def = graph_def;
// Maps from name:port of a feed to the name:port of the placeholder to use.
std::unordered_map<string, string> feed_remapping;
TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
&feed_remapping, &first_copy_def));
// Prune the GraphDef first so that unknown ops that we aren't compiling get
// filtered out.
GraphDef second_copy_def;
TF_RETURN_IF_ERROR(
PruneGraphDefInto(config, first_copy_def, &second_copy_def));
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
&second_copy_def, *g->op_registry(), /*node_offset=*/0));
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
second_copy_def, g.get()));
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
*graph = std::move(g);
return Status::OK();
}
} // namespace
Status ConvertGraphDefToXla(const GraphDef& graph_def,
const tf2xla::Config& config, xla::Client* client,
xla::Computation* computation,
bool* requires_runtime_context) {
std::unique_ptr<Graph> graph;
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation,
requires_runtime_context));
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,43 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/core/framework/graph.pb.h"
namespace tensorflow {
// Converts a tensorflow::GraphDef into an xla::Computation. The given `config`
// specifies the portion of the graph to convert, via feeds and fetches. Each
// feed is a positional input argument for the generated computation, while each
// fetch is a positional output argument.
//
// The computation is built in the context of the given `client`, which may
// subsequently be used to compile or execute the computation.
//
// If `requires_runtime_context` is filled with true, this indicates the last
// argument of the computation is XlaLocalRuntimeContext*.
Status ConvertGraphDefToXla(const GraphDef& graph_def,
const tf2xla::Config& config, xla::Client* client,
xla::Computation* computation,
bool* requires_runtime_context);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_

View File

@ -1,10 +1,10 @@
syntax = "proto3";
package tensorflow.tfcompile;
package tensorflow.tf2xla;
option cc_enable_arenas = true;
option java_outer_classname = "CompileProtos";
option java_outer_classname = "Tf2XlaProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.tfcompile";
option java_package = "org.tensorflow.tf2xla";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
@ -19,32 +19,32 @@ message TensorId {
};
// Feed represents a single feed tensor in the graph, which corresponds to an
// input argument for the generated function.
// input argument for the generated computation.
message Feed {
TensorId id = 1;
TensorShapeProto shape = 2;
string name = 3; // Optional name for generated code.
// Optional data type. This is not normally required, as the graph itself
// contains this information. However, if the node being fed is an op that
// is not linked into the tfcompile binary, then the type cannot be inferred
// from the node; in this case, the type should be set here.
// contains this information. However, if the node being fed is an op that is
// not linked into the binary, then the type cannot be inferred from the node;
// in this case, the type should be set here.
DataType type = 4;
};
// Fetch represents a single fetch tensor in the graph, which corresponds to an
// output argument for the generated function.
// output argument for the generated computation.
message Fetch {
TensorId id = 1;
string name = 2; // Optional name for generated code.
};
// Config represents configuration information for tfcompile.
// Config represents configuration information for tf2xla conversion.
message Config {
// Each feed is a positional input argument for the generated function. The
// order of each entry matches the order of each input argument.
// Each feed is a positional input argument for the generated computation.
// The order of each entry matches the order of each input argument.
repeated Feed feed = 1;
// Each fetch is a positional output argument for the generated function. The
// order of each entry matches the order of each output argument.
// Each fetch is a positional output argument for the generated computation.
// The order of each entry matches the order of each output argument.
repeated Fetch fetch = 2;
};

View File

@ -0,0 +1,99 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
AttrValue TypeAttrValue(DataType type) {
AttrValue attr_value;
SetAttrValue(type, &attr_value);
return attr_value;
}
GraphDef SumGraph() {
GraphDef graph_def;
NodeDef* x = graph_def.add_node();
x->set_name("x");
x->set_op("Placeholder");
(*x->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
NodeDef* y = graph_def.add_node();
y->set_name("y");
y->set_op("Placeholder");
(*y->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
NodeDef* sum = graph_def.add_node();
sum->set_name("sum");
sum->set_op("Add");
sum->add_input("x");
sum->add_input("y");
(*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32);
return graph_def;
}
tf2xla::Config SumConfig() {
tf2xla::Config config;
config.add_feed()->mutable_id()->set_node_name("x");
config.add_feed()->mutable_id()->set_node_name("y");
config.add_fetch()->mutable_id()->set_node_name("sum");
return config;
}
TEST(ConvertGraphDefToXla, Sum) {
GraphDef graph_def = SumGraph();
tf2xla::Config config = SumConfig();
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
xla::Computation computation;
bool requires_runtime_context;
TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation,
&requires_runtime_context));
ASSERT_FALSE(requires_runtime_context);
// Set up arguments.
auto x_literal = xla::Literal::CreateR0<int32>(10);
auto y_literal = xla::Literal::CreateR0<int32>(32);
auto x_global_or = client->TransferToServer(*x_literal);
auto y_global_or = client->TransferToServer(*y_literal);
TF_EXPECT_OK(x_global_or.status());
TF_EXPECT_OK(y_global_or.status());
std::unique_ptr<xla::GlobalData> x_global =
std::move(x_global_or.ValueOrDie());
std::unique_ptr<xla::GlobalData> y_global =
std::move(y_global_or.ValueOrDie());
// Execute and check result.
auto result_or =
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
EXPECT_EQ("42", result->ToString());
}
} // namespace
} // namespace tensorflow

View File

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include <queue>
#include <set>
#include <unordered_map>
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -29,21 +29,13 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace tfcompile {
namespace {
bool IsAlpha(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}
bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
Status ValidateTensorId(const TensorId& id) {
Status ValidateTensorId(const tf2xla::TensorId& id) {
if (id.node_name().empty()) {
return errors::InvalidArgument("TensorId node_name must be non-empty");
}
@ -53,10 +45,9 @@ Status ValidateTensorId(const TensorId& id) {
return Status::OK();
}
Status ValidateFeedFetchName(const string& kind, const string& name,
std::set<string>* names) {
Status CheckNameDuplicates(const string& kind, const string& name,
std::set<string>* names) {
if (!name.empty()) {
TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name"));
if (!names->insert(name).second) {
return errors::InvalidArgument("duplicate ", kind, " name: ", name);
}
@ -80,42 +71,18 @@ Status CheckFeedFetchNameConflicts(const string& kind,
} // namespace
Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
if (ident.empty()) {
return errors::InvalidArgument("empty identifier: ", msg);
}
// Require that the identifier starts with a nondigit, and is composed of
// nondigits and digits, as specified in section [2.11 Identifiers] of the
// C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
// defined as [0-9].
//
// Technically the standard also allows for `universal-character-name`, with a
// table of allowed unicode ranges, as well as `other implementation-defined
// characters`. We disallow those here to give better error messages, at the
// expensive of being more restrictive than the standard.
if (ident[0] != '_' && !IsAlpha(ident[0])) {
return errors::InvalidArgument("illegal leading char: ", msg);
}
for (size_t pos = 1; pos < ident.size(); ++pos) {
if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
return errors::InvalidArgument("illegal char: ", msg);
}
}
return Status::OK();
}
Status ValidateConfig(const Config& config) {
Status ValidateConfig(const tf2xla::Config& config) {
std::set<string> names;
for (const Feed& feed : config.feed()) {
for (const tf2xla::Feed& feed : config.feed()) {
TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names));
TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
}
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
names.clear();
for (const Fetch& fetch : config.fetch()) {
for (const tf2xla::Fetch& fetch : config.fetch()) {
TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names));
TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
}
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
if (config.feed().empty() || config.fetch().empty()) {
@ -125,10 +92,10 @@ Status ValidateConfig(const Config& config) {
}
Status AddPlaceholdersForFeeds(
const Config& config, const OpRegistryInterface* op_registry,
const tf2xla::Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
struct PlaceholderInfo {
const Feed* feed = nullptr; // point to Feed in <config>.
const tf2xla::Feed* feed = nullptr; // point to Feed in <config>.
string placeholder_name;
DataType data_type = DT_INVALID;
};
@ -137,9 +104,9 @@ Status AddPlaceholdersForFeeds(
// when creating placeholders (genrules want deterministic output).
std::map<string, PlaceholderInfo> placeholder_info;
for (int i = 0; i < config.feed_size(); ++i) {
const Feed* feed = &config.feed(i);
const tf2xla::Feed* feed = &config.feed(i);
const string name_port = TensorIdToString(feed->id());
auto& info = placeholder_info[name_port];
PlaceholderInfo& info = placeholder_info[name_port];
info.feed = feed;
info.placeholder_name = strings::StrCat(
"aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
@ -153,7 +120,7 @@ Status AddPlaceholdersForFeeds(
}
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
PlaceholderInfo& info = it->second;
const TensorId& feed_id = info.feed->id();
const tf2xla::TensorId& feed_id = info.feed->id();
// Find the existing node and determine data type.
auto node_it = name_to_node.find(feed_id.node_name());
@ -214,16 +181,16 @@ Status AddPlaceholdersForFeeds(
return Status::OK();
}
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
GraphDef* out) {
*out = in;
out->clear_node();
// Tensors needed for feeding.
std::set<std::pair<string, int>> feed_tensors;
for (const auto& feed_config : config.feed()) {
feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
feed_config.id().output_index()));
for (const tf2xla::Feed& feed : config.feed()) {
feed_tensors.insert(
std::make_pair(feed.id().node_name(), feed.id().output_index()));
}
// Maps node name to reachability.
@ -279,9 +246,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
return Status::OK();
}
string TensorIdToString(const TensorId& id) {
string TensorIdToString(const tf2xla::TensorId& id) {
return strings::StrCat(id.node_name(), ":", id.output_index());
}
} // namespace tfcompile
} // namespace tensorflow

View File

@ -13,26 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#include <unordered_map>
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
namespace tfcompile {
// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
// appended to error messages.
Status ValidateCppIdent(StringPiece ident, StringPiece msg);
// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const Config& config);
Status ValidateConfig(const tf2xla::Config& config);
// Modifies <graph_def> to include placeholders for each fed tensor, and
// update references to the fed tensors to refer to the placeholders.
@ -40,18 +34,17 @@ Status ValidateConfig(const Config& config);
// (except where their input edges are modified by the replacement of other
// feeds).
Status AddPlaceholdersForFeeds(
const Config& config, const OpRegistryInterface* op_registry,
const tf2xla::Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
GraphDef* out);
// Returns node:port for the given <id>.
string TensorIdToString(const TensorId& id);
string TensorIdToString(const tf2xla::TensorId& id);
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace {
void ExpectErrorContains(const Status& status, StringPiece str) {
@ -32,45 +31,16 @@ void ExpectErrorContains(const Status& status, StringPiece str) {
<< "expected error: " << status.error_message() << " to contain: " << str;
}
TEST(ValidateCppIdent, Simple) {
TF_EXPECT_OK(ValidateCppIdent("a", ""));
TF_EXPECT_OK(ValidateCppIdent("abc", ""));
TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
// Make sure we didn't skip a valid letter or digit
string ident;
for (char c = 'a'; c <= 'z'; c++) {
ident.append(1, c);
}
for (char c = 'A'; c <= 'Z'; c++) {
ident.append(1, c);
}
for (char c = '0'; c <= '9'; c++) {
ident.append(1, c);
}
ident += "_";
TF_EXPECT_OK(ValidateCppIdent(ident, ""));
ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
}
TEST(ValidateConfig, Good) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->mutable_id()->set_output_index(123);
feed->set_name("foo_debug");
feed = config.add_feed();
feed->mutable_id()->set_node_name("bar");
feed->mutable_id()->set_output_index(0);
Fetch* fetch = config.add_fetch();
tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("baz");
fetch->mutable_id()->set_output_index(456);
fetch->set_name("baz_debug");
@ -81,62 +51,62 @@ TEST(ValidateConfig, Good) {
}
TEST(ValidateConfig, BadEmpty) {
Config config;
tf2xla::Config config;
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadNoFeed) {
Config config;
Fetch* fetch = config.add_fetch();
tf2xla::Config config;
tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("foo");
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadNoFetch) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadFeedNodeName) {
Config config;
tf2xla::Config config;
config.add_feed();
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
TEST(ValidateConfig, BadFeedOutputIndex) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->mutable_id()->set_output_index(-1);
ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}
TEST(ValidateConfig, BadFetchNodeName) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
config.add_fetch();
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
TEST(ValidateConfig, BadFetchOutputIndex) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
Fetch* fetch = config.add_fetch();
tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->mutable_id()->set_output_index(-1);
ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}
TEST(ValidateConfig, DuplicateFeedName) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->set_name("dup");
feed = config.add_feed();
@ -146,10 +116,10 @@ TEST(ValidateConfig, DuplicateFeedName) {
}
TEST(ValidateConfig, DuplicateFetchName) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
Fetch* fetch = config.add_fetch();
tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->set_name("dup");
fetch = config.add_fetch();
@ -159,8 +129,8 @@ TEST(ValidateConfig, DuplicateFetchName) {
}
TEST(ValidateConfig, ConflictingFeedName) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->set_name("conflict");
feed = config.add_feed();
@ -170,10 +140,10 @@ TEST(ValidateConfig, ConflictingFeedName) {
}
TEST(ValidateConfig, ConflictingFetchName) {
Config config;
Feed* feed = config.add_feed();
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
Fetch* fetch = config.add_fetch();
tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->set_name("conflict");
fetch = config.add_fetch();
@ -182,8 +152,8 @@ TEST(ValidateConfig, ConflictingFetchName) {
ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
}
static Config FetchesConfig(std::vector<string> fetches) {
Config config;
static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
tf2xla::Config config;
for (const auto& fetch_node_name : fetches) {
auto* fetch = config.add_fetch();
fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
@ -242,5 +212,4 @@ TEST(PruneGraphDefInto, Basic) {
}
} // namespace
} // namespace tfcompile
} // namespace tensorflow

View File

@ -1703,7 +1703,6 @@ StatusOr<Computation> ComputationBuilder::Build() {
}
void ComputationBuilder::AddOpMetadata(OpRequest* request) const {
tensorflow::mutex_lock lock(mutex_);
*request->mutable_metadata() = metadata_;
}

View File

@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"
@ -57,10 +56,10 @@ class ComputationBuilder {
~ComputationBuilder();
// Returns the client the builder was initialized with.
Client* client() { return client_; }
Client* client() const { return client_; }
// Returns the computation name.
const string& name() { return name_; }
const string& name() const { return name_; }
// Sets OpMetadata that will be added to all instructions until cleared.
//
@ -69,13 +68,11 @@ class ComputationBuilder {
// instructions generated via this Computation Builder will have the same
// OpMetadata attached until a call to ClearOpMetdata.
void SetOpMetadata(const OpMetadata& metadata) {
tensorflow::mutex_lock lock(mutex_);
metadata_ = metadata;
}
// Clears the HloMetdata state.
void ClearOpMetadata() {
tensorflow::mutex_lock lock(mutex_);
metadata_.Clear();
}
@ -826,15 +823,12 @@ class ComputationBuilder {
Client* client_;
// Mode bit that indicates whether to die when a first error is encountered.
bool die_immediately_on_error_{false};
// Mutex to guard against concurrent access to metadata_.
mutable tensorflow::mutex mutex_;
bool die_immediately_on_error_ = false;
// The metadata to attach to each op. This is structured as a "modal"-like
// operation, in order to simplify client code (and not sprinkle this metadata
// throughout the TensorFlow op kernel implementations).
OpMetadata metadata_ GUARDED_BY(mutex_);
OpMetadata metadata_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};

View File

@ -180,15 +180,18 @@ cc_library(
cc_library(
name = "ir_emitter",
srcs = ["ir_emitter.cc"],
srcs = [
"elemental_ir_emitter.cc",
"ir_emitter.cc",
],
hdrs = [
"elemental_ir_emitter.h",
"ir_emitter.h",
],
deps = [
":cpu_options",
":cpu_runtime",
":dot_op_emitter",
":elemental_ir_emitter",
":ir_emission_utils",
":simple_orc_jit",
"//tensorflow/compiler/xla:shape_util",
@ -525,22 +528,6 @@ cc_library(
],
)
cc_library(
name = "elemental_ir_emitter",
srcs = ["elemental_ir_emitter.cc"],
hdrs = ["elemental_ir_emitter.h"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"@llvm//:core",
],
)
cc_library(
name = "ir_emission_utils",
srcs = ["ir_emission_utils.cc"],

View File

@ -50,14 +50,6 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
// Producer or consumer cannot be Map. Maps are technically elementwise but
// of a slightly different form (call instead of a computation). These are not
// yet supported in the CPU backend.
if (producer->opcode() == HloOpcode::kMap ||
consumer->opcode() == HloOpcode::kMap) {
return false;
}
// Cost condition: not fuse (simple, expensive producers) and (consumers who
// reuse operand elements).
if (producer->opcode() != HloOpcode::kFusion &&

View File

@ -209,6 +209,31 @@ class OpcodeFusionTest : public InstructionFusionTest {
std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
expected_opcodes);
}
HloComputation* CreateAdderToOne(HloModule* module) {
HloComputation::Builder builder(TestName());
HloInstruction* arg0 =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "arg0"));
HloInstruction* one = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
return module->AddEmbeddedComputation(builder.Build());
}
HloComputation* CreateMax(HloModule* module) {
HloComputation::Builder builder(TestName());
HloInstruction* arg0 =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "arg0"));
HloInstruction* arg1 =
builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {}), "arg1"));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1));
return module->AddEmbeddedComputation(builder.Build());
}
};
TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) {
@ -402,6 +427,49 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
HloOpcode::kParameter});
}
TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
HloInstruction* exp = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
builder.AddInstruction(HloInstruction::CreateMap(
shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{}));
module->AddEntryComputation(builder.Build());
RunFusionAndCheckOpcodesWereFused(
module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap});
}
TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, shape, "param"));
HloInstruction* exp0 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
HloInstruction* exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1));
builder.AddInstruction(HloInstruction::CreateMap(
shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{}));
module->AddEntryComputation(builder.Build());
RunFusionAndCheckOpcodesWereFused(
module.get(), {HloOpcode::kParameter, HloOpcode::kParameter,
HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap});
}
} // namespace
} // namespace cpu
} // namespace xla

View File

@ -64,5 +64,25 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
}
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) const {
if (hlo->opcode() == HloOpcode::kMap) {
return [this, hlo, &operand_to_generator](
const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
std::vector<llvm::Value*> operands;
for (int i = 0; i < hlo->operand_count(); i++) {
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(i))(
ElementwiseSourceIndex(index, *hlo, 0)));
operands.push_back(operand_value);
}
return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
hlo->to_apply(), operands,
llvm_ir::IrName(hlo));
};
}
return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
}
} // namespace cpu
} // namespace xla

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -29,12 +30,19 @@ namespace cpu {
class CpuElementalIrEmitter : public ElementalIrEmitter {
public:
CpuElementalIrEmitter(const HloModuleConfig& module_config,
llvm::IRBuilder<>* ir_builder, llvm::Module* module)
: ElementalIrEmitter(module_config, module, ir_builder) {}
IrEmitter* ir_emitter, llvm::Module* module)
: ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()),
ir_emitter_(ir_emitter) {}
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) const override;
protected:
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
IrEmitter* ir_emitter_;
};
} // namespace cpu

View File

@ -136,6 +136,10 @@ DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
const int64 kReductionDimensionThresholdBytes = 8 * 1024;
const bool single_threaded_eigen =
!dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();
// This is the point at which it is better to call into Eigen and shard the
// dot across multiple worker threads. This is a rough estimate by running
// a matmult benchmark on my local machine, and it can be tuned further.
const int64 kMaxSingleThreadedFlops = 16 * 1024;
const int64 M = result_shape.dimensions(0);

View File

@ -2354,8 +2354,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArrayForOp(operand));
}
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
module_);
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
@ -2737,14 +2736,10 @@ llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) {
}
prof_counter_idx = it->second;
uintptr_t hlo_address = reinterpret_cast<uintptr_t>(hlo);
counter_name = tensorflow::strings::StrCat(
"prof_counter_0x",
tensorflow::strings::Hex(
hlo_address, tensorflow::strings::PadSpec(sizeof(hlo_address))));
counter_name = IrName("prof_counter", hlo->name());
} else {
prof_counter_idx = hlo_to_profile_idx_->size();
counter_name = "prof_counter_computation";
counter_name = "prof_counter.computation";
}
return ir_builder_.CreateGEP(GetProfileCountersArgument(),
ir_builder_.getInt64(prof_counter_idx),
@ -3180,12 +3175,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_);
};
}
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
module_);
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
return EmitTargetElementLoop(
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
PrimitiveType return_type, HloComputation* computation,
const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
std::vector<llvm::Value*> argument_addrs;
for (auto argument : arguments) {
llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
argument->getType(), "arg_addr", &ir_builder_);
ir_builder_.CreateStore(argument, argument_addr);
argument_addrs.push_back(argument_addr);
}
return EmitElementFunctionCall(llvm_function,
ShapeUtil::MakeShape(return_type, {}),
argument_addrs, name);
}
unsigned TargetMachineFeatures::largest_register_size_in_bytes(
llvm::Function* function) {
auto itr = largest_register_size_in_bytes_.find(function);

View File

@ -133,6 +133,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
bool is_top_level_computation,
std::vector<const HloInstruction*>* instruction_order);
llvm::IRBuilder<>* ir_builder() { return &ir_builder_; }
// Emits a call to `computation` with scalar arguments `arguments`.
StatusOr<llvm::Value*> EmitScalarCall(
PrimitiveType return_type, HloComputation* computation,
const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
protected:
//
// The following methods implement the DfsHloVisitor interface.

View File

@ -76,10 +76,11 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
// Since CUDA 9.0, all GPU versions are included in a single file
const char* unified_libdevice_filename = "libdevice.10.bc";
std::vector<string> unified_libdevice_files;
tensorflow::Env::Default()->GetMatchingPaths(
const tensorflow::Status status =
tensorflow::Env::Default()->GetMatchingPaths(
tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename),
&unified_libdevice_files);
if( unified_libdevice_files.size() == 1 ) {
if (status.ok() && unified_libdevice_files.size() == 1) {
return unified_libdevice_filename;
}
// There are only four libdevice files: compute_{20,30,35,50}. Each GPU

View File

@ -77,7 +77,7 @@ class HloOrdering {
// Precondition: 'a' and 'b' are in the same computation.
//
// Derived classes should implement this method for determining order of
// instructions in the same comptuation. ExecutesBefore() analyzes the
// instructions in the same computation. ExecutesBefore() analyzes the
// callgraph and uses this method to determine ordering of instructions in
// different computations.
virtual bool ExecutesBeforeInSameComputation(

View File

@ -50,7 +50,7 @@ class WhileTest : public ClientLibraryTestBase {};
// while (result < 5) {
// result = result + 1;
// }
TEST_F(WhileTest, WhileWithScalarResult) {
TEST_F(WhileTest, WhileWithScalarS32Result) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
@ -81,6 +81,43 @@ TEST_F(WhileTest, WhileWithScalarResult) {
ComputeAndCompareR0<int32>(&builder, 5, {});
}
// Tests a while node when the result type T is S64.
//
// int32 result = 0;
// while (result < 5) {
// result = result + 1;
// }
TEST_F(WhileTest, WhileWithScalarS64Result) {
auto result_shape = ShapeUtil::MakeShape(S64, {});
// Create a computation for the condition: repeat for 5 iterations.
Computation condition;
{
ComputationBuilder builder(client_, "condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int64>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
Computation body;
{
ComputationBuilder builder(client_, "body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int64>(1);
auto result = builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
ComputationBuilder builder(client_, TestName());
auto init = builder.ConstantR0<int64>(0);
auto result = builder.While(condition, body, init);
auto shape = builder.GetShape(result).ConsumeValueOrDie();
ComputeAndCompareR0<int64>(&builder, 5, {});
}
TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
auto orig_shape = ShapeUtil::MakeShape(S32, {2});

View File

@ -33,6 +33,7 @@ py_library(
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/hooks",
"//tensorflow/contrib/image:distort_image_py",
"//tensorflow/contrib/image:image_py",
"//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
"//tensorflow/contrib/imperative",

View File

@ -20,10 +20,10 @@ import android.os.Build.VERSION;
import android.os.Trace;
import android.text.TextUtils;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
@ -79,24 +79,32 @@ public class TensorFlowInferenceInterface {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
try {
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
byte[] graphDef = new byte[is.available()];
final int numBytesRead = is.read(graphDef);
if (numBytesRead != graphDef.length) {
throw new IOException(
"read error: read only "
+ numBytesRead
+ " of the graph, expected to read "
+ graphDef.length);
}
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
}
loadGraph(graphDef, g);
is.close();
Log.i(TAG, "Successfully loaded model from '" + model + "'");
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
@ -121,13 +129,13 @@ public class TensorFlowInferenceInterface {
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
try {
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
int numBytesRead;
@ -143,7 +151,7 @@ public class TensorFlowInferenceInterface {
loadGraph(graphDef, g);
Log.i(TAG, "Successfully loaded model from the input stream");
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
@ -309,8 +317,8 @@ public class TensorFlowInferenceInterface {
/**
* Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
* scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of
* bytes, not a Java {@code String} (which is a sequence of characters).
* scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
* a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[] src) {
addFeed(inputName, Tensor.create(src));
@ -318,9 +326,8 @@ public class TensorFlowInferenceInterface {
/**
* Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
* string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string"
* is an arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of
* characters).
* string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[][] src) {
addFeed(inputName, Tensor.create(src));

View File

@ -151,7 +151,7 @@ def convert_to_universal_format(dtec, sorted_feature_names,
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
inequality_test.threshold.float_value = split.threshold
elif node_type == "sparse_float_binary_split_default_right":
split = gtflow_node.sparse_float_binary_split_default_right
split = gtflow_node.sparse_float_binary_split_default_right.split
node.default_direction = (
generic_tree_model_pb2.BinaryNode.RIGHT)
feature_id = split.feature_column + num_dense

View File

@ -329,7 +329,6 @@ add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests")
add_python_module("tensorflow/contrib/cudnn_rnn/python/ops")
add_python_module("tensorflow/contrib/data")
add_python_module("tensorflow/contrib/data/python")
add_python_module("tensorflow/contrib/data/python/framework")
add_python_module("tensorflow/contrib/data/python/kernel_tests")
add_python_module("tensorflow/contrib/data/python/ops")
add_python_module("tensorflow/contrib/data/python/util")
@ -362,6 +361,8 @@ add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/gan/python")
add_python_module("tensorflow/contrib/gan/python/eval")
add_python_module("tensorflow/contrib/gan/python/eval/python")
add_python_module("tensorflow/contrib/gan/python/features")
add_python_module("tensorflow/contrib/gan/python/features/python")
add_python_module("tensorflow/contrib/gan/python/losses")

View File

@ -20,6 +20,7 @@
@@FixedLengthRecordDataset
@@TextLineDataset
@@batch_and_drop_remainder
@@read_batch_features
@@rejection_resample
@@group_by_window
@ -32,6 +33,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.dataset_ops import batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window

View File

@ -1,48 +0,0 @@
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "function",
srcs = ["function.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
],
)
py_test(
name = "function_test",
size = "medium",
srcs = ["function_test.py"],
srcs_version = "PY2AND3",
deps = [
":function",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -1,275 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An experimental fork of the Python TensorFlow-function library.
NOTE: functions are currently experimental and subject to change!
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import tf_inspect
# NOTE(mrry): This is an experimental extension of a core class that wasn't
# designed to be extended, so we disable protected access checks for the
# whole file.
# pylint: disable=protected-access
class _ExperimentalFuncGraph(function._FuncGraph):
"""A helper for construction a function (supporting capture-by-value).
_ExperimentalFuncGraph overrides ops.Graph's create_op() so that we can keep
track of every inputs into every op created inside the function. If
any input is from other graphs, we keep track of it in self.capture
and substitute the input with a place holder.
Each captured input's corresponding place holder is converted into a
function argument and the caller passes in the captured tensor.
"""
def __init__(self, capture_by_value, *args, **kwargs):
super(_ExperimentalFuncGraph, self).__init__(*args, **kwargs)
self._capture_by_value = capture_by_value
self._building_function = True
self._outer_graph = ops.get_default_graph()
self._vscope = vs.get_variable_scope()
self._old_custom_getter = self._vscope.custom_getter
self._captured = {}
self.extra_inputs = []
self.extra_args = []
self.extra_vars = []
def create_op(self, op_type, inputs, data_types, **kwargs):
for i, x in enumerate(inputs):
if x.graph is not self:
# Referring to a tensor from other graph.
if x in self._captured:
# Captured already.
inputs[i] = self._captured[x]
elif self._capture_by_value:
inputs[i] = self._add_tensor_and_parents(x)
else:
# Substitute with a placeholder.
self.extra_inputs.append(x)
ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
# pylint: disable=protected-access
ph._handle_data = x._handle_data
# pylint: enable=protected-access
inputs[i] = ph
self._captured[x] = ph
self.extra_args.append(ph)
return super(_ExperimentalFuncGraph, self).create_op(op_type, inputs,
data_types, **kwargs)
def _add_tensor_and_parents(self, tensor):
op = self._add_op_and_parents(tensor.op)
return op.outputs[tensor.value_index]
def _add_op_and_parents(self, op):
op_def = graph_to_function_def._get_op_def(op)
if op_def.is_stateful:
raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
"by value." % (op.name, op.type))
elif op.type in ("Placeholder", "PlaceholderV2"):
raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
"by value." % (op.name, op.type))
captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]
captured_op = self.create_op(op.type, captured_inputs,
[o.dtype for o in op.outputs],
name=op.name, attrs=op.node_def.attr,
op_def=op_def)
for t, captured_t in zip(op.outputs, captured_op.outputs):
self._captured[t] = captured_t
return captured_op
class _ExperimentalDefinedFunction(function._DefinedFunction):
"""Overrides _DefinedFunction with support for capture-by-value."""
def __init__(self,
func,
argnames,
input_types,
func_name=None,
grad_func=None,
python_grad_func=None,
out_names=None,
shape_func=None,
capture_by_value=False,
**kwargs):
"""Creates an _ExperimentalDefinedFunction.
Args:
func: A python callable which constructs a tf function body.
argnames: A list of strings for function argument names.
input_types: The function's argument types. Can be a tuple, list of
tf data types.
func_name: The function name. Defaults to None, in which derives from
'func'.
grad_func: This function's gradient function, if not None. Defaults
to None.
python_grad_func: A python callable implementing the gradient of
the function python-side.
out_names: An optional list of strings for the function return value
names.
shape_func: An optional function mapping an op to a list of static
output shapes.
capture_by_value: Boolean (defaults to False). If True, captured values
will be copied into the function body.
**kwargs: The keyword arguments. **kwargs is passed to every call
site of this function.
Raises:
ValueError: The function definition is invalid.
"""
super(_ExperimentalDefinedFunction, self).__init__(
func, argnames, input_types, func_name, grad_func, python_grad_func,
out_names, shape_func, **kwargs)
self._capture_by_value = capture_by_value
def _create_definition_if_needed(self):
"""Creates the function definition if it's not created yet."""
with context.graph_mode():
self._create_definition_if_needed_impl()
def _create_definition_if_needed_impl(self):
"""You're looking for _create_definition_if_needed(), not this."""
if self._definition is not None:
return
# Create the func_def object.
temp_graph = _ExperimentalFuncGraph(capture_by_value=self._capture_by_value)
with temp_graph.as_default():
# List of placeholders for the function_def.
inputs = []
for (argname, argtype) in self._args:
argholder = array_ops.placeholder(argtype, name=argname)
inputs.append(argholder)
# Call func and gather the output tensors.
with vs.variable_scope("", custom_getter=temp_graph.getvar):
outputs = self._func(*inputs)
# If func only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
if any([_ is None for _ in outputs]):
raise ValueError("Function can not return None.")
# Ensures each output is a Tensor.
outputs = [ops.convert_to_tensor(_) for _ in outputs]
self._extra_inputs = temp_graph.extra_inputs
inputs.extend(temp_graph.extra_args)
self._sub_functions = temp_graph._functions
# Build the FunctionDef
self._definition = graph_to_function_def.graph_to_function_def(
temp_graph, temp_graph.get_operations(), inputs, outputs,
out_names=self._out_names)
# Extra kwargs are treated as attrs on the function def.
sig_pre_func_name = self._func_name or function._get_func_name(self._func)
kwargs_attr = function._parse_kwargs_as_attrs(
sig_pre_func_name, **self._extra_kwargs)
for k in kwargs_attr:
self._definition.attr[k].CopyFrom(kwargs_attr[k])
# Hash the definition and its dependencies.
self._hash_str = self._create_hash_str(
self._definition.signature.input_arg,
self._definition.signature.output_arg,
self._definition.node_def)
# Finally, we decide the function name to use. If not specified,
# make up something which is almost certainly unique (but deterministic).
if not self._func_name:
self._func_name = "_".join([function._get_func_name(self._func),
self._hash_str])
self._definition.signature.name = self._func_name
if self._func.__doc__:
self._definition.signature.description = self._func.__doc__
class Defun(function.Defun):
"""Experimental version of Defun supporting capture-by-value."""
def __init__(self, *input_types, **kwargs):
"""Create an experimental `Defun` decorator.
Args:
*input_types: A list of `tf.DType`
**kwargs: Optional keyword arguments (see `function.Defun`) plus:
capture_by_value - Boolean (defaults to False). If True, captured values
will be copied into the function body.
"""
super(Defun, self).__init__(*input_types, **kwargs)
def __call__(self, func):
# Various sanity checks on the callable func.
if not callable(func):
raise ValueError("func %s must be callable" % func)
# Func should not use kwargs and defaults.
argspec = tf_inspect.getargspec(func)
if argspec.keywords or argspec.defaults:
raise ValueError("Functions with argument defaults or keyword "
"arguments are not supported.")
# Computes how many arguments 'func' has.
min_args = len(argspec.args)
max_args = min_args
if argspec.varargs:
max_args = 1000000
argnames = argspec.args
if tf_inspect.ismethod(func):
# 1st argument is the "class" type.
min_args -= 1
argnames = argnames[1:]
if self._input_types:
# If Defun is given a list of types for the inputs, the number
# of input types should be compatible with 'func'.
num = len(self._input_types)
if num < min_args or num > max_args:
raise ValueError(
"The function has fewer arguments than the number of specified "
"input types.")
return _ExperimentalDefinedFunction(
func, argnames, self._input_types, self._func_name, self._grad_func,
self._python_grad_func, out_names=self._out_names,
**self._extra_kwargs)
# 'func' expects no arguments and input types is an empty list.
if min_args == 0 and max_args == 0:
return _ExperimentalDefinedFunction(
func, [], [], self._func_name, self._grad_func,
self._python_grad_func, out_names=self._out_names,
**self._extra_kwargs)
# Input types are unknown. It's an overloaded function and hence
# its definition needs to be deferred until it's called.
return function._OverloadedFunction(
func, argnames, self._func_name, self._grad_func,
self._python_grad_func, out_names=self._out_names, **self._extra_kwargs)

View File

@ -1,59 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for experimental capture-by-value feature in TF functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.framework import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class FunctionTest(test.TestCase):
def testCaptureByValue(self):
g = ops.Graph()
with g.as_default():
w = constant_op.constant([[1.0]])
b = constant_op.constant([2.0])
# Foo() captures w and b.
@function.Defun(dtypes.float32, capture_by_value=True)
def Foo(x):
# Plus() captures b.
@function.Defun(dtypes.float32, capture_by_value=True)
def Plus(y):
return y + b
self.assertEqual(0, len(Plus.captured_inputs))
return Plus(math_ops.matmul(w, x))
y = Foo(constant_op.constant([[10.]]))
self.assertEqual(0, len(Foo.captured_inputs))
with self.test_session(graph=g):
self.assertAllEqual(y.eval(), [[12.0]])
if __name__ == "__main__":
test.main()

View File

@ -264,6 +264,7 @@ py_test(
srcs = ["resample_test.py"],
shard_count = 2,
srcs_version = "PY2AND3",
tags = ["noasan"],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/python:client_testlib",

View File

@ -333,6 +333,55 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
def testBatchAndDropRemainder(self):
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7))
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.apply(dataset_ops.batch_and_drop_remainder(batch_size))
.make_initializable_iterator())
next_element = iterator.get_next()
with self.test_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
for i in range(num_batches):
result = sess.run(next_element)
for component, result_component in zip(components, result):
for j in range(test_batch_size):
self.assertAllEqual(component[(i * test_batch_size + j)],
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testBatchAndDropRemainderShapeInference(self):
components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder(
dtypes.int32, shape=[None]), array_ops.placeholder(
dtypes.int32, shape=[20, 30])))
# Test with a statically known batch size.
dataset = (dataset_ops.Dataset.from_tensor_slices(components)
.apply(dataset_ops.batch_and_drop_remainder(128)))
self.assertIs(None, dataset.output_shapes[0].ndims)
self.assertEqual([128], dataset.output_shapes[1][0].as_list())
self.assertEqual([128, 30], dataset.output_shapes[1][1].as_list())
# Test with a dynamic batch size: the static shape will be unknown, because
# `batch_size` is a placeholder.
batch_size = array_ops.placeholder(dtypes.int64)
dataset = (dataset_ops.Dataset.from_tensor_slices(components)
.apply(dataset_ops.batch_and_drop_remainder(batch_size)))
self.assertIs(None, dataset.output_shapes[0].ndims)
self.assertEqual([None], dataset.output_shapes[1][0].as_list())
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
if __name__ == "__main__":
test.main()

View File

@ -22,13 +22,17 @@ import threading
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.util import nest
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
class DatasetConstructorTest(test.TestCase):
@ -475,6 +479,75 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testSplitPipelineFailsWithPlacementError(self):
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
dataset = dataset_ops.Dataset.from_tensors(0)
# Define a pipeline that attempts to use variables on two
# different devices.
#
# Initialize the variables before creating to iterator, to avoid the
# placement algorithm overriding the DT_RESOURCE colocation constraints.
with ops.device("/cpu:0"):
var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
dataset = dataset.map(lambda x: x + var_0.read_value())
sess.run(var_0.initializer)
with ops.device("/cpu:1"):
var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
dataset = dataset.map(lambda x: x + var_1.read_value())
sess.run(var_1.initializer)
iterator = dataset.make_initializable_iterator()
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Trying to access resource located in device"):
sess.run(iterator.initializer)
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),
(array_ops.placeholder(dtypes.int32, shape=[None]),
array_ops.placeholder(dtypes.int32, shape=[20, 30])))
dataset = dataset_ops.Dataset.from_tensors(components)
i32 = dtypes.int32
test_cases = [((i32, i32, i32), None),
(((i32, i32), i32), None),
((i32, i32, i32), (None, None, None)),
((i32, i32, i32), ([17], [17], [20, 30]))]
for new_types, new_shape_lists in test_cases:
# pylint: disable=protected-access
new = dataset_ops._RestructuredDataset(
dataset, new_types, new_shape_lists)
# pylint: enable=protected-access
self.assertEqual(new_types, new.output_types)
if new_shape_lists is not None:
for expected_shape_list, shape in zip(
nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)):
if expected_shape_list is None:
self.assertIs(None, shape.ndims)
else:
self.assertEqual(expected_shape_list, shape.as_list())
fail_cases = [((i32, dtypes.int64, i32), None),
((i32, i32, i32, i32), None),
((i32, i32, i32), ((None, None), None)),
((i32, i32, i32), (None, None, None, None)),
((i32, i32, i32), (None, [None], [21, 30]))]
for new_types, new_shape_lists in fail_cases:
with self.assertRaises(ValueError):
# pylint: disable=protected-access
new = dataset_ops._RestructuredDataset(
dataset, new_types, new_shape_lists)
# pylint: enable=protected-access
if __name__ == "__main__":
test.main()

View File

@ -21,9 +21,11 @@ import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.util import compat
@ -50,7 +52,7 @@ class ResampleTest(test.TestCase):
seed=27))
init_op = iterator.initializer
get_next = iterator.get_next()
variable_init_op = variables.global_variables_initializer()
variable_init_op = variables.local_variables_initializer()
with self.test_session() as sess:
sess.run(variable_init_op)
@ -74,6 +76,22 @@ class ResampleTest(test.TestCase):
returned_dist = class_counts / total_returned
self.assertAllClose(target_dist, returned_dist, atol=1e-2)
def testVariableDevicePlacement(self):
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
with ops.device(
device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")):
dataset = (dataset_ops.Dataset.from_tensor_slices(classes)
.shuffle(200, seed=21)
.map(lambda c: (c, string_ops.as_string(c))))
dataset = dataset_ops.rejection_resample(
dataset, target_dist=target_dist, initial_dist=None,
class_func=lambda c, _: c, seed=27)
self.assertEqual(1, len(variables.local_variables()))
self.assertEqual(b"",
compat.as_bytes(variables.local_variables()[0].device))
if __name__ == "__main__":
test.main()

View File

@ -9,7 +9,6 @@ py_library(
srcs = ["dataset_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/framework:function",
"//tensorflow/contrib/data/python/util:nest",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
@ -17,6 +16,7 @@ py_library(
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:logging_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
@ -38,11 +38,11 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":dataset_ops",
"//tensorflow/contrib/data/python/framework:function",
"//tensorflow/contrib/data/python/util:nest",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:platform",
],
)

View File

@ -23,10 +23,10 @@ import threading
import numpy as np
from tensorflow.contrib.data.python.framework import function
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@ -99,8 +99,9 @@ class Iterator(object):
shared_name=shared_name,
output_types=nest.flatten(dataset.output_types),
output_shapes=nest.flatten(dataset.output_shapes))
initializer = gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
iterator_resource)
with ops.colocate_with(iterator_resource):
initializer = gen_dataset_ops.make_iterator(
dataset.make_dataset_resource(), iterator_resource)
return Iterator(iterator_resource, initializer, dataset.output_types,
dataset.output_shapes)
@ -291,6 +292,7 @@ class Iterator(object):
raise TypeError("Expected output shapes compatible with %r but got "
"dataset with output shapes %r." %
(self._output_shapes, dataset.output_shapes))
with ops.colocate_with(self._iterator_resource):
return gen_dataset_ops.make_iterator(
dataset.make_dataset_resource(), self._iterator_resource, name=name)
@ -2404,12 +2406,16 @@ def rejection_resample(dataset,
num_classes = (target_dist.shape[0].value or
array_ops.shape(target_dist)[0])
smoothing_constant = 10
num_examples_per_class_seen = resource_variable_ops.ResourceVariable(
initial_value=array_ops.fill([num_classes],
np.int64(smoothing_constant)),
trainable=False,
name="class_count",
dtype=dtypes.int64)
# Disable device functions and colocation constraints so that the variable
# will be placed with the eventual DT_VARIANT dataset tensor.
with ops.colocate_with(None, ignore_existing=True):
num_examples_per_class_seen = resource_variable_ops.ResourceVariable(
initial_value=array_ops.fill([num_classes],
np.int64(smoothing_constant)),
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="local_class_count",
dtype=dtypes.int64)
def update_estimate_and_tile(c):
return array_ops.tile(
@ -2519,7 +2525,13 @@ def read_batch_features(file_pattern,
dataset = reader(filenames, *reader_args)
else:
dataset = reader(filenames)
dataset = dataset.repeat(num_epochs)
if dataset.output_types == (dtypes.string, dtypes.string):
dataset = dataset.map(lambda unused_k, v: v)
elif dataset.output_types != dtypes.string:
raise TypeError("`reader` must be a dataset of `tf.string` values, "
"or `(tf.string, tf.string)` key-value pairs.")
if num_epochs != 1:
dataset = dataset.repeat(num_epochs)
if randomize_input:
dataset = dataset.shuffle(capacity)
dataset = dataset.batch(batch_size)
@ -2729,3 +2741,137 @@ def group_by_window(dataset,
assert window_size_func is not None
return GroupByWindowDataset(dataset, key_func, reduce_func, window_size_func)
class _RestructuredDataset(Dataset):
"""An internal helper for changing the structure and shape of a dataset."""
def __init__(self, dataset, output_types, output_shapes=None):
"""Creates a new dataset with the given output types and shapes.
The given `dataset` must have a structure that is convertible:
* `dataset.output_types` must be the same as `output_types` module nesting.
* Each shape in `dataset.output_shapes` must be compatible with each shape
in `output_shapes` (if given).
Note: This helper permits "unsafe casts" for shapes, equivalent to using
`tf.Tensor.set_shape()` where domain-specific knowledge is available.
Args:
dataset: A `Dataset` object.
output_types: A nested structure of `tf.DType` objects.
output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
If omitted, the shapes will be inherited from `dataset`.
Raises:
ValueError: If either `output_types` or `output_shapes` is not compatible
with the structure of `dataset`.
"""
super(_RestructuredDataset, self).__init__()
self._dataset = dataset
# Validate that the types are compatible.
output_types = nest.map_structure(dtypes.as_dtype, output_types)
flat_original_types = nest.flatten(dataset.output_types)
flat_new_types = nest.flatten(output_types)
if flat_original_types != flat_new_types:
raise ValueError(
"Dataset with output types %r cannot be restructured to have output "
"types %r" % (dataset.output_types, output_types))
self._output_types = output_types
if output_shapes is None:
# Inherit shapes from the original `dataset`.
self._output_shapes = nest.pack_sequence_as(
output_types, nest.flatten(dataset.output_shapes))
else:
# Validate that the shapes are compatible.
nest.assert_same_structure(output_types, output_shapes)
flat_original_shapes = nest.flatten(dataset.output_shapes)
flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
for original_shape, new_shape in zip(flat_original_shapes,
flat_new_shapes):
if not original_shape.is_compatible_with(new_shape):
raise ValueError(
"Dataset with output shapes %r cannot be restructured to have "
"incompatible output shapes %r"
% (dataset.output_shapes, output_shapes))
self._output_shapes = nest.map_structure_up_to(
output_types, tensor_shape.as_shape, output_shapes)
def make_dataset_resource(self):
return self._dataset.make_dataset_resource()
@property
def output_types(self):
return self._output_types
@property
def output_shapes(self):
return self._output_shapes
def batch_and_drop_remainder(batch_size):
"""A batching transformation that omits the final small batch (if present).
Like @{tf.contrib.data.Dataset.batch}, this transformation combines
consecutive elements of this dataset into batches. However, if the batch
size does not evenly divide the input dataset size, this transformation will
drop the final smaller element.
The following example illustrates the difference between this
transformation and `Dataset.batch()`:
```python
dataset = tf.contrib.data.Dataset.range(200)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))
print(batched.output_shapes) # ==> "(128,)" (the batch dimension is known)
```
By contrast, `dataset.batch(128)` would yield a two-element dataset with
shapes `(128,)` and `(72,)`, so the batch dimension would not be statically
known.
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements of this dataset to combine in a single batch.
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.contrib.data.Dataset.apply}
"""
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
tensor_batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
batched = dataset.batch(tensor_batch_size)
flattened = _RestructuredDataset(batched,
tuple(nest.flatten(batched.output_types)))
def _predicate(*xs):
"""Return `True` if this element is a full batch."""
# Extract the dynamic batch size from the first component of the flattened
# batched element.
first_component = xs[0]
first_component_batch_size = array_ops.shape(
first_component, out_type=dtypes.int64)[0]
return math_ops.equal(first_component_batch_size, tensor_batch_size)
filtered = flattened.filter(_predicate)
maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
def _set_first_dimension(shape):
return shape.merge_with(
tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
known_shapes = nest.map_structure(_set_first_dimension,
batched.output_shapes)
return _RestructuredDataset(filtered, batched.output_types, known_shapes)
return _apply_fn

View File

@ -17,10 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.framework import function
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops

View File

@ -30,6 +30,7 @@ cuda_py_test(
":tfe",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:test",
],
)

View File

@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.eager.python import tfe
from tensorflow.python.platform import test
from tensorflow.python.eager import test
class TFETest(test.TestCase):
@ -31,6 +31,14 @@ class TFETest(test.TestCase):
devices = tfe.list_devices()
self.assertEqual(len(devices) - 1, tfe.num_gpus())
def testCallingEnableEagerExecutionMoreThanOnce(self):
# Note that eager.test.main() has already invoked enable_eager_exceution().
with self.assertRaisesRegexp(
ValueError,
r"Do not call tfe\.%s more than once in the same process" %
tfe.enable_eager_execution.__name__):
tfe.enable_eager_execution()
if __name__ == "__main__":
test.main()

View File

@ -25,11 +25,46 @@ py_library(
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
":dnn",
":extenders",
":head",
],
)
py_library(
name = "dnn",
srcs = ["python/estimator/dnn.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:nn",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:dnn",
],
)
py_test(
name = "dnn_test",
size = "small",
srcs = ["python/estimator/dnn_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":dnn",
":head",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:summary",
"//tensorflow/python/estimator:dnn_testing_utils",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:prediction_keys",
"//tensorflow/python/feature_column",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_library(
name = "extenders",
srcs = [
@ -68,6 +103,37 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:metrics",
"//tensorflow/python:summary",
"//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
"//tensorflow/python/ops/losses",
],
)
py_test(
name = "head_test",
size = "small",
srcs = ["python/estimator/head_test.py"],
srcs_version = "PY2AND3",
deps = [
":head",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
"//tensorflow/python/saved_model:signature_constants",
"//third_party/py/numpy",
"@six_archive//:six",
],
)

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.extenders import *
from tensorflow.contrib.estimator.python.estimator.head import *
@ -29,7 +30,9 @@ _allowed_symbols = [
'add_metrics',
'binary_classification_head',
'multi_class_head',
'multi_label_head',
'regression_head',
'DNNEstimator',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

View File

@ -0,0 +1,134 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deep Neural Network estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import dnn as dnn_lib
from tensorflow.python.ops import nn
class DNNEstimator(estimator.Estimator):
"""An estimator for TensorFlow DNN models with user-specified head.
Example:
```python
sparse_feature_a = sparse_column_with_hash_bucket(...)
sparse_feature_b = sparse_column_with_hash_bucket(...)
sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a,
...)
sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,
...)
estimator = DNNEstimator(
head=tf.contrib.estimator.multi_label_head(n_classes=3),
feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
hidden_units=[1024, 512, 256])
# Or estimator using the ProximalAdagradOptimizer optimizer with
# regularization.
estimator = DNNEstimator(
head=tf.contrib.estimator.multi_label_head(n_classes=3),
feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
hidden_units=[1024, 512, 256],
optimizer=tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001
))
# Input builders
def input_fn_train: # returns x, y
pass
estimator.train(input_fn=input_fn_train, steps=100)
def input_fn_eval: # returns x, y
pass
metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
def input_fn_predict: # returns x, None
pass
predictions = estimator.predict(input_fn=input_fn_predict)
```
Input of `train` and `evaluate` should have following features,
otherwise there will be a `KeyError`:
* if `weight_column` is not `None`, a feature with
`key=weight_column` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `_WeightedCategoricalColumn`, two features: the first
with `key` the id column name, the second with `key` the weight column
name. Both features' `value` must be a `SparseTensor`.
- if `column` is a `_DenseColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
Loss and predicted output are determined by the specified head.
"""
def __init__(self,
head,
hidden_units,
feature_columns,
model_dir=None,
optimizer='Adagrad',
activation_fn=nn.relu,
dropout=None,
input_layer_partitioner=None,
config=None):
"""Initializes a `DNNClassifier` instance.
Args:
head: A `_Head` instance constructed with a method such as
`tf.contrib.estimator.multi_label_head`.
hidden_units: Iterable of number hidden units per layer. All layers are
fully connected. Ex. `[64, 32]` means first layer has 64 nodes and
second one has 32.
feature_columns: An iterable containing all the feature columns used by
the model. All items in the set should be instances of classes derived
from `_FeatureColumn`.
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
continue training a previously saved model.
optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
to Adagrad optimizer.
activation_fn: Activation function applied to each layer. If `None`, will
use `tf.nn.relu`.
dropout: When not `None`, the probability we will drop out a given
coordinate.
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
"""
def _model_fn(features, labels, mode, config):
return dnn_lib._dnn_model_fn( # pylint: disable=protected-access
features=features,
labels=labels,
mode=mode,
head=head,
hidden_units=hidden_units,
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
activation_fn=activation_fn,
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config)
super(DNNEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)

View File

@ -0,0 +1,153 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dnn.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import shutil
import tempfile
import numpy as np
import six
from tensorflow.contrib.estimator.python.estimator import dnn
from tensorflow.contrib.estimator.python.estimator import head as head_lib
from tensorflow.python.estimator.canned import dnn_testing_utils
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
"""Returns a DNNEstimator that uses regression_head."""
return dnn.DNNEstimator(
head=head_lib.regression_head(
weight_column=weight_column, label_dimension=label_dimension),
*args, **kwargs)
class DNNEstimatorEvaluateTest(
dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
self, _dnn_estimator_fn)
class DNNEstimatorPredictTest(
dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
self, _dnn_estimator_fn)
class DNNEstimatorTrainTest(
dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
self, _dnn_estimator_fn)
class DNNEstimatorIntegrationTest(test.TestCase):
def setUp(self):
self._model_dir = tempfile.mkdtemp()
def tearDown(self):
if self._model_dir:
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
def _test_complete_flow(
self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
label_dimension, batch_size):
feature_columns = [
feature_column.numeric_column('x', shape=(input_dimension,))]
est = dnn.DNNEstimator(
head=head_lib.regression_head(label_dimension=label_dimension),
hidden_units=(2, 2),
feature_columns=feature_columns,
model_dir=self._model_dir)
# TRAIN
num_steps = 10
est.train(train_input_fn, steps=num_steps)
# EVALUTE
scores = est.evaluate(eval_input_fn)
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
self.assertIn('loss', six.iterkeys(scores))
# PREDICT
predictions = np.array([
x[prediction_keys.PredictionKeys.PREDICTIONS]
for x in est.predict(predict_input_fn)
])
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
feature_spec = feature_column.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
def test_numpy_input_fn(self):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
data = data.reshape(batch_size, label_dimension)
# learn y = x
train_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
y=data,
batch_size=batch_size,
num_epochs=None,
shuffle=True)
eval_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
y=data,
batch_size=batch_size,
shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
batch_size=batch_size,
shuffle=False)
self._test_complete_flow(
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
batch_size=batch_size)
if __name__ == '__main__':
test.main()

View File

@ -18,7 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
def multi_class_head(n_classes,
@ -33,7 +46,7 @@ def multi_class_head(n_classes,
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
`binary_classification_head`).
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
@ -123,3 +136,206 @@ def regression_head(weight_column=None,
weight_column=weight_column,
label_dimension=label_dimension,
head_name=head_name)
# TODO(roumposg): Support label_vocabulary.
def multi_label_head(n_classes,
weight_column=None,
thresholds=None,
head_name=None):
"""Creates a `_Head` for multi-label classification.
Multi-label classification handles the case where each example may have zero
or more associated labels, from a discrete set. This is distinct from
`multi_class_head` which has exactly one label per example.
Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a
multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer
`SparseTensor` of class indices.
Args:
n_classes: Number of classes, must be greater than 1 (for 1 class, use
`binary_classification_head`).
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
and recall metrics are evaluated for each threshold value. The threshold
is applied to the predicted probabilities, i.e. above the threshold is
`true`, below is `false`.
head_name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + head_name`.
Returns:
An instance of `_Head` for multi-label classification.
Raises:
ValueError: if `n_classes` or `thresholds` is invalid.
"""
thresholds = tuple(thresholds) if thresholds else tuple()
if n_classes is None or n_classes < 2:
raise ValueError(
'n_classes must be > 1 for multi-class classification. '
'Given: {}'.format(n_classes))
for threshold in thresholds:
if (threshold <= 0.0) or (threshold >= 1.0):
raise ValueError(
'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
return _MultiLabelHead(
n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
head_name=head_name)
class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
"""`_Head` for multi-label classification."""
def __init__(self,
n_classes,
weight_column=None,
thresholds=None,
head_name=None):
self._n_classes = n_classes
self._weight_column = weight_column
self._thresholds = thresholds
self._head_name = head_name
@property
def logits_dimension(self):
return self._n_classes
def _process_labels(self, labels):
if isinstance(labels, sparse_tensor.SparseTensor):
return math_ops.to_int64(
sparse_ops.sparse_to_indicator(labels, self._n_classes))
msg = ('labels shape must be [batch_size, {}]. '
'Given: ').format(self._n_classes)
labels_shape = array_ops.shape(labels)
check_rank_op = control_flow_ops.Assert(
math_ops.equal(array_ops.rank(labels), 2),
data=[msg, labels_shape])
check_label_dim = control_flow_ops.Assert(
math_ops.equal(labels_shape[-1], self._n_classes),
data=[msg, labels_shape])
with ops.control_dependencies([check_rank_op, check_label_dim]):
return array_ops.identity(labels)
def create_loss(self, features, mode, logits, labels):
"""See `Head`."""
del mode, features # Unused for this head.
processed_labels = self._process_labels(labels)
unweighted_loss = losses.sigmoid_cross_entropy(
multi_class_labels=processed_labels, logits=logits,
reduction=losses.Reduction.NONE)
return head_lib.LossAndLabels(
unweighted_loss=unweighted_loss,
processed_labels=processed_labels)
def create_estimator_spec(
self, features, mode, logits, labels=None, train_op_fn=None):
"""See `Head`."""
with ops.name_scope('head'):
logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access
# Predict.
pred_keys = prediction_keys.PredictionKeys
with ops.name_scope(None, 'predictions', (logits,)):
probabilities = math_ops.sigmoid(logits, name=pred_keys.PROBABILITIES)
predictions = {
pred_keys.LOGITS: logits,
pred_keys.PROBABILITIES: probabilities,
}
if mode == model_fn.ModeKeys.PREDICT:
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
'': export_output.ClassificationOutput(scores=probabilities)
})
# Eval.
unweighted_loss, _ = self.create_loss(
features=features, mode=mode, logits=logits, labels=labels)
# Averages loss over classes.
per_example_loss = math_ops.reduce_mean(
unweighted_loss, axis=-1, keep_dims=True)
weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access
training_loss = losses.compute_weighted_loss(
per_example_loss, weights=weights, reduction=losses.Reduction.SUM)
if mode == model_fn.ModeKeys.EVAL:
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=training_loss,
eval_metric_ops=self._eval_metric_ops(
labels=labels,
probabilities=probabilities,
weights=weights,
per_example_loss=per_example_loss))
# Train.
if train_op_fn is None:
raise ValueError('train_op_fn can not be None.')
with ops.name_scope(''):
summary.scalar(
head_lib._summary_key(self._head_name, metric_keys.MetricKeys.LOSS), # pylint:disable=protected-access
training_loss)
summary.scalar(
head_lib._summary_key( # pylint:disable=protected-access
self._head_name, metric_keys.MetricKeys.LOSS_MEAN),
losses.compute_weighted_loss(
unweighted_loss, weights=weights,
reduction=losses.Reduction.MEAN))
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=training_loss,
train_op=train_op_fn(training_loss))
def _eval_metric_ops(self, labels, probabilities, weights, per_example_loss):
"""Returns a dict of metrics for eval_metric_ops."""
with ops.name_scope(
None, 'metrics', [labels, probabilities, weights, per_example_loss]):
keys = metric_keys.MetricKeys
metric_ops = {
# Estimator already adds a metric for loss.
head_lib._summary_key(self._head_name, keys.LOSS_MEAN): # pylint:disable=protected-access
metrics_lib.mean(
per_example_loss, weights=weights, name=keys.LOSS_MEAN),
head_lib._summary_key(self._head_name, keys.AUC): # pylint:disable=protected-access
metrics_lib.auc(
labels=labels, predictions=probabilities, weights=weights,
name=keys.AUC),
head_lib._summary_key(self._head_name, keys.AUC_PR): # pylint:disable=protected-access
metrics_lib.auc(
labels=labels, predictions=probabilities, weights=weights,
curve='PR', name=keys.AUC_PR),
}
for threshold in self._thresholds:
accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
metric_ops[head_lib._summary_key(self._head_name, accuracy_key)] = ( # pylint:disable=protected-access
head_lib._accuracy_at_threshold( # pylint:disable=protected-access
labels=labels,
predictions=probabilities,
weights=weights,
threshold=threshold,
name=accuracy_key))
# Precision for positive examples.
precision_key = keys.PRECISION_AT_THRESHOLD % threshold
metric_ops[head_lib._summary_key(self._head_name, precision_key)] = ( # pylint:disable=protected-access
head_lib._precision_at_threshold( # pylint:disable=protected-access
labels=labels,
predictions=probabilities,
weights=weights,
threshold=threshold,
name=precision_key))
# Recall for positive examples.
recall_key = keys.RECALL_AT_THRESHOLD % threshold
metric_ops[head_lib._summary_key(self._head_name, recall_key)] = ( # pylint:disable=protected-access
head_lib._recall_at_threshold( # pylint:disable=protected-access
labels=labels,
predictions=probabilities,
weights=weights,
threshold=threshold,
name=recall_key))
return metric_ops

View File

@ -0,0 +1,570 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for head."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six
from tensorflow.contrib.estimator.python.estimator import head as head_lib
from tensorflow.core.framework import summary_pb2
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import monitored_session
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
def _initialize_variables(test_case, scaffold):
scaffold.finalize()
test_case.assertIsNone(scaffold.init_feed_dict)
test_case.assertIsNone(scaffold.init_fn)
scaffold.init_op.run()
scaffold.ready_for_local_init_op.eval()
scaffold.local_init_op.run()
scaffold.ready_op.eval()
test_case.assertIsNotNone(scaffold.saver)
def _assert_simple_summaries(test_case, expected_summaries, summary_str,
tol=1e-6):
"""Assert summary the specified simple values.
Args:
test_case: test case.
expected_summaries: Dict of expected tags and simple values.
summary_str: Serialized `summary_pb2.Summary`.
tol: Tolerance for relative and absolute.
"""
summary = summary_pb2.Summary()
summary.ParseFromString(summary_str)
test_case.assertAllClose(expected_summaries, {
v.tag: v.simple_value for v in summary.value
}, rtol=tol, atol=tol)
def _assert_no_hooks(test_case, spec):
test_case.assertAllEqual([], spec.training_chief_hooks)
test_case.assertAllEqual([], spec.training_hooks)
def _sigmoid(logits):
return 1 / (1 + np.exp(-logits))
def _sigmoid_cross_entropy(labels, logits):
sigmoid_logits = _sigmoid(logits)
return (-labels * np.log(sigmoid_logits)
-(1 - labels) * np.log(1 - sigmoid_logits))
class MultiLabelHead(test.TestCase):
def setUp(self):
ops.reset_default_graph()
def test_n_classes_is_none(self):
with self.assertRaisesRegexp(
ValueError,
r'n_classes must be > 1 for multi-class classification\. Given: None'):
head_lib.multi_label_head(n_classes=None)
def test_n_classes_is_1(self):
with self.assertRaisesRegexp(
ValueError,
r'n_classes must be > 1 for multi-class classification\. Given: 1'):
head_lib.multi_label_head(n_classes=1)
def test_threshold_too_small(self):
with self.assertRaisesRegexp(
ValueError,
r'thresholds must be in \(0, 1\) range\. Given: 0\.0'):
head_lib.multi_label_head(n_classes=2, thresholds=[0., 0.5])
def test_threshold_too_large(self):
with self.assertRaisesRegexp(
ValueError,
r'thresholds must be in \(0, 1\) range\. Given: 1\.0'):
head_lib.multi_label_head(n_classes=2, thresholds=[0.5, 1.0])
def test_predict(self):
n_classes = 4
head = head_lib.multi_label_head(n_classes)
self.assertEqual(n_classes, head.logits_dimension)
logits = np.array(
[[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
expected_probabilities = _sigmoid(logits)
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
self.assertItemsEqual(
('', _DEFAULT_SERVING_KEY), spec.export_outputs.keys())
# Assert predictions and export_outputs.
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
self.assertAllClose(logits,
predictions[prediction_keys.PredictionKeys.LOGITS])
self.assertAllClose(
expected_probabilities,
predictions[prediction_keys.PredictionKeys.PROBABILITIES])
self.assertAllClose(
expected_probabilities,
sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))
def test_weight_should_not_impact_prediction(self):
n_classes = 4
head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
self.assertEqual(n_classes, head.logits_dimension)
logits = np.array(
[[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
expected_probabilities = _sigmoid(logits)
weights_2x1 = [[1.], [2.]]
spec = head.create_estimator_spec(
features={
'x': np.array(((42,),), dtype=np.int32),
'label_weights': weights_2x1,
},
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
# Assert predictions and export_outputs.
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
self.assertAllClose(logits,
predictions[prediction_keys.PredictionKeys.LOGITS])
self.assertAllClose(
expected_probabilities,
predictions[prediction_keys.PredictionKeys.PROBABILITIES])
def test_eval_create_loss(self):
"""Tests head.create_loss for eval mode."""
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
expected_unweighted_loss = _sigmoid_cross_entropy(
labels=labels, logits=logits)
actual_unweighted_loss, _ = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
with self.test_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_unweighted_loss, actual_unweighted_loss.eval())
def test_eval_create_loss_large_logits(self):
"""Tests head.create_loss for eval mode and large logits."""
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# For large logits, this is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits
expected_unweighted_loss = np.array(
[[10., 10.], [15., 0.]], dtype=np.float32)
actual_unweighted_loss, _ = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
with self.test_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
def test_eval_create_loss_sparse_labels(self):
"""Tests head.create_loss for eval mode and sparse labels."""
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = sparse_tensor.SparseTensor(
values=[0, 0, 1],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
expected_labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# For large logits, this is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits
expected_unweighted_loss = np.array(
[[10., 10.], [15., 0.]], dtype=np.float32)
actual_unweighted_loss, actual_labels = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
with self.test_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllEqual(expected_labels, actual_labels.eval())
self.assertAllClose(
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
def test_eval_create_loss_labels_wrong_shape(self):
"""Tests head.create_loss for eval mode when labels has the wrong shape."""
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
labels_placeholder = array_ops.placeholder(dtype=dtypes.int64)
actual_unweighted_loss, _ = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels_placeholder)
with self.test_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'):
actual_unweighted_loss.eval(
{labels_placeholder: np.array([[1], [1]], dtype=np.int64)})
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'):
actual_unweighted_loss.eval(
{labels_placeholder: np.array([1, 1], dtype=np.int64)})
def test_eval(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# Average over classes, and sum over examples.
expected_loss = (
np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
)
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
keys.LOSS_MEAN: expected_loss / 2,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
keys.AUC_PR: 0.7639,
}
# Assert spec contains expected tensors.
self.assertIsNotNone(spec.loss)
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
tol = 1e-3
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
loss, metrics = sess.run((spec.loss, update_ops))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
# Check results of both update (in `metrics`) and value ops.
self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
self.assertAllClose(
expected_metrics, {k: value_ops[k].eval() for k in value_ops},
rtol=tol,
atol=tol)
def test_eval_with_thresholds(self):
n_classes = 2
thresholds = [0.25, 0.5, 0.75]
head = head_lib.multi_label_head(n_classes, thresholds=thresholds)
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# Average over classes, and sum over examples.
expected_loss = (
np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
)
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
keys.LOSS_MEAN: expected_loss / 2,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
keys.AUC_PR: 0.7639,
keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4.,
keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3.,
keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3.,
keys.ACCURACY_AT_THRESHOLD % thresholds[1]: 1. / 4.,
keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1. / 2.,
keys.RECALL_AT_THRESHOLD % thresholds[1]: 1. / 3.,
keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 2. / 4.,
keys.PRECISION_AT_THRESHOLD % thresholds[2]: 1. / 1.,
keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3.,
}
# Assert spec contains expected tensors.
self.assertIsNotNone(spec.loss)
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
tol = 1e-3
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
loss, metrics = sess.run((spec.loss, update_ops))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
# Check results of both update (in `metrics`) and value ops.
self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
self.assertAllClose(
expected_metrics, {k: value_ops[k].eval() for k in value_ops},
rtol=tol,
atol=tol)
def test_eval_with_weights(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# For large logits, sigmoid cross entropy loss is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
# Average over classes, weighted sum over examples.
expected_loss = 25.
spec = head.create_estimator_spec(
features={
'x': np.array([[41], [42]], dtype=np.int32),
'label_weights': np.array([[1.], [2.]], dtype=np.float32),
},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over weighted examples.
keys.LOSS_MEAN: expected_loss / 3,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.2000,
keys.AUC_PR: 0.7833,
}
# Assert spec contains expected tensors.
self.assertIsNotNone(spec.loss)
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
tol = 1e-3
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
loss, metrics = sess.run((spec.loss, update_ops))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
# Check results of both update (in `metrics`) and value ops.
self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
self.assertAllClose(
expected_metrics, {k: value_ops[k].eval() for k in value_ops},
rtol=tol,
atol=tol)
def test_train_create_loss_large_logits(self):
"""Tests head.create_loss for train mode and large logits."""
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# For large logits, this is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits
expected_unweighted_loss = np.array(
[[10., 10.], [15., 0.]], dtype=np.float32)
actual_unweighted_loss, _ = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
with self.test_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
def test_train(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# For large logits, sigmoid cross entropy loss is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
# Average over classes, sum over weights.
expected_loss = 17.5
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
[constant_op.constant(expected_train_result),
string_ops.as_string(loss, precision=3)])
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
self.assertIsNotNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
self.assertIsNotNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
spec.scaffold.summary_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
self.assertEqual(
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
_assert_simple_summaries(self, {
metric_keys.MetricKeys.LOSS: expected_loss,
# Average loss over examples.
metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
}, summary_str, tol)
def test_train_with_weights(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# For large logits, sigmoid cross entropy loss is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
# Average over classes, weighted sum over examples.
expected_loss = 25.
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
[constant_op.constant(expected_train_result),
string_ops.as_string(loss, precision=3)])
spec = head.create_estimator_spec(
features={
'x': np.array([[41], [42]], dtype=np.int32),
'label_weights': np.array([[1.], [2.]], dtype=np.float32),
},
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
self.assertIsNotNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
self.assertIsNotNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
spec.scaffold.summary_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
self.assertEqual(
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
_assert_simple_summaries(self, {
metric_keys.MetricKeys.LOSS: expected_loss,
# Average loss over weighted examples.
metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
}, summary_str, tol)
if __name__ == '__main__':
test.main()

View File

@ -221,6 +221,7 @@ tf_py_test(
"manual",
"noasan", # times out b/63678675
"nomsan",
"notsan",
],
)

View File

@ -14,10 +14,12 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":eval",
":features",
":losses",
":namedtuples",
":train",
"//tensorflow/python:util",
],
)
@ -73,6 +75,18 @@ py_test(
],
)
py_library(
name = "eval",
srcs = ["python/eval/__init__.py"],
srcs_version = "PY2AND3",
deps = [
":classifier_metrics",
":eval_utils",
":summaries",
"//tensorflow/python:util",
],
)
py_library(
name = "losses",
srcs = ["python/losses/__init__.py"],
@ -257,6 +271,105 @@ py_test(
],
)
py_library(
name = "classifier_metrics",
srcs = [
"python/eval/python/classifier_metrics.py",
"python/eval/python/classifier_metrics_impl.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:functional_ops",
"//tensorflow/python:image_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
],
)
py_test(
name = "classifier_metrics_test",
srcs = ["python/eval/python/classifier_metrics_test.py"],
srcs_version = "PY2AND3",
deps = [
":classifier_metrics",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_library(
name = "eval_utils",
srcs = [
"python/eval/python/eval_utils.py",
"python/eval/python/eval_utils_impl.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
],
)
py_test(
name = "eval_utils_test",
srcs = ["python/eval/python/eval_utils_test.py"],
srcs_version = "PY2AND3",
deps = [
":eval_utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "summaries",
srcs = [
"python/eval/python/summaries.py",
"python/eval/python/summaries_impl.py",
],
srcs_version = "PY2AND3",
deps = [
":eval_utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:summary",
"//tensorflow/python:util",
"//tensorflow/python/ops/losses",
],
)
py_test(
name = "summaries_test",
srcs = ["python/eval/python/summaries_test.py"],
srcs_version = "PY2AND3",
deps = [
":namedtuples",
":summaries",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:summary",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# Collapse TFGAN into a tiered namespace.
from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses
from tensorflow.contrib.gan.python import namedtuples
@ -32,6 +33,7 @@ from tensorflow.contrib.gan.python.train import *
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'eval',
'features',
'losses',
]

View File

@ -0,0 +1,39 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API. Please see README.md for details and usage."""
# pylint: disable=,wildcard-import,unused-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Collapse eval into a single namespace.
from tensorflow.contrib.gan.python.eval.python import classifier_metrics
from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.contrib.gan.python.eval.python import summaries
from tensorflow.contrib.gan.python.eval.python.classifier_metrics import *
from tensorflow.contrib.gan.python.eval.python.eval_utils import *
from tensorflow.contrib.gan.python.eval.python.summaries import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'classifier_metrics',
'summaries',
'eval_utils',
] + classifier_metrics.__all__ + summaries.__all__ + eval_utils.__all__
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model evaluation tools for TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.eval.python import classifier_metrics_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.eval.python.classifier_metrics_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = classifier_metrics_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,401 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model evaluation tools for TFGAN.
These methods come from https://arxiv.org/abs/1606.03498 and
https://arxiv.org/abs/1706.08500.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sys
import tarfile
from six.moves import urllib
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import gfile
__all__ = [
'get_graph_def_from_disk',
'preprocess_image',
'run_image_classifier',
'run_inception',
'inception_score',
'classifier_score',
'frechet_inception_distance',
'frechet_classifier_distance',
]
INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v3_2017_09_13.tar.gz'
INCEPTION_FROZEN_GRAPH = 'frozen_inception_v3.pb'
INCEPTION_V3_INPUT = 'inputs'
INCEPTION_V3_OUTPUT = 'InceptionV3/Logits/SpatialSqueeze:0'
INCEPTION_V3_FINAL_POOL = 'InceptionV3/Logits/AvgPool_1a_8x8/AvgPool:0'
_INCEPTION_V3_NUM_CLASSES = 1001
_INCEPTION_V3_FINAL_POOL_SIZE = 2048
INCEPTION_V3_DEFAULT_IMG_SIZE = 299
def _validate_images(images, image_size):
images = ops.convert_to_tensor(images)
images.shape.with_rank(4)
images.shape.assert_is_compatible_with(
[None, image_size, image_size, None])
return images
def _matrix_square_root(mat, eps=1e-10):
"""Compute symmetric square root of matrix.
Equivalent to matrix square root when matrix is invertible; note that this is
different from an elementwise square root. We want to compute M' where M' =
sqrt(mat) such that M' * M' = mat.
Args:
mat: Matrix to take the square root of.
eps: Small epsilon such that any element less than eps will not be square
rooted to guard against numerical instability.
Returns:
Matrix square root of mat.
"""
s, u, v = linalg_ops.svd(mat)
# sqrt is unstable around 0, just use 0 in such case
si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
return math_ops.matmul(
math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)
# Convenience preprocessing function, with fixed defaults.
# NOTE: Floating-point inputs are expected to be in [0, 1].
# Copied from /tensorflow_models/slim/preprocessing/inception_preprocessing.py.
def preprocess_image(
image, height=INCEPTION_V3_DEFAULT_IMG_SIZE,
width=INCEPTION_V3_DEFAULT_IMG_SIZE, central_fraction=0.875, scope=None):
"""Prepare one image for evaluation.
If height and width are specified it would output an image with that size by
applying resize_bilinear.
If central_fraction is specified it would crop the central fraction of the
input image.
Args:
image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
[0, 1], otherwise it would converted to tf.float32 assuming that the range
is [0, MAX], where MAX is largest positive representable number for
int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
height: integer
width: integer
central_fraction: Optional Float, fraction of the image to crop.
scope: Optional scope for name_scope.
Returns:
3-D float Tensor of prepared image.
"""
with ops.name_scope(scope, 'eval_image', [image, height, width]):
if image.dtype != dtypes.float32:
image = image_ops.convert_image_dtype(image, dtype=dtypes.float32)
# Crop the central region of the image with an area containing 87.5% of
# the original image.
image = image_ops.central_crop(image, central_fraction=central_fraction)
# Resize the image to the specified height and width.
image = array_ops.expand_dims(image, 0)
image = image_ops.resize_bilinear(image, [height, width],
align_corners=False)
image = array_ops.squeeze(image, [0])
image = (image - 0.5) * 2.0
return image
def _kl_divergence(p, p_logits, q):
"""Computes the Kullback-Liebler divergence between p and q.
This function uses p's logits in some places to improve numerical stability.
Specifically:
KL(p || q) = sum[ p * log(p / q) ]
= sum[ p * ( log(p) - log(q) ) ]
= sum[ p * ( log_softmax(p_logits) - log(q) ) ]
Args:
p: A 2-D floating-point Tensor p_ij, where `i` corresponds to the minibatch
example and `j` corresponds to the probability of being in class `j`.
p_logits: A 2-D floating-point Tensor corresponding to logits for `p`.
q: A 1-D floating-point Tensor, where q_j corresponds to the probability
of class `j`.
Returns:
KL divergence between two distributions. Output dimension is 1D, one entry
per distribution in `p`.
Raises:
ValueError: If any of the inputs aren't floating-point.
ValueError: If p or p_logits aren't 2D.
ValueError: If q isn't 1D.
"""
for tensor in [p, p_logits, q]:
if not tensor.dtype.is_floating:
raise ValueError('Input %s must be floating type.', tensor.name)
p.shape.assert_has_rank(2)
p_logits.shape.assert_has_rank(2)
q.shape.assert_has_rank(1)
return math_ops.reduce_sum(
p * (nn_ops.log_softmax(p_logits) - math_ops.log(q)), axis=1)
def get_graph_def_from_disk(filename):
"""Get a GraphDef proto from a disk location."""
with gfile.FastGFile(filename, 'rb') as f:
return graph_pb2.GraphDef.FromString(f.read())
def get_graph_def_from_url_tarball(url, filename):
"""Get a GraphDef proto from a tarball on the web."""
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
url, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
tar_filename, _ = urllib.request.urlretrieve(url, reporthook=_progress)
with tarfile.open(tar_filename, 'r:gz') as tar:
proto_str = tar.extractfile(filename).read()
return graph_pb2.GraphDef.FromString(proto_str)
def _default_graph_def_fn():
return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH)
def run_inception(images,
graph_def=None,
default_graph_def_fn=_default_graph_def_fn,
image_size=INCEPTION_V3_DEFAULT_IMG_SIZE,
input_tensor=INCEPTION_V3_INPUT,
output_tensor=INCEPTION_V3_OUTPUT):
"""Run images through a pretrained Inception classifier.
Args:
images: Input tensors. Must be [batch, height, width, channels]. Input shape
and values must be in [-1, 1], which can be achieved using
`preprocess_image`.
graph_def: A GraphDef proto of a pretrained Inception graph. If `None`,
call `default_graph_def_fn` to get GraphDef.
default_graph_def_fn: A function that returns a GraphDef. Used if
`graph_def` is `None. By default, returns a pretrained InceptionV3 graph.
image_size: Required image width and height. See unit tests for the default
values.
input_tensor: Name of input Tensor.
output_tensor: Name of output Tensor. This function will compute activations
at the specified layer. Examples include INCEPTION_V3_OUTPUT and
INCEPTION_V3_FINAL_POOL which would result in this function computing
the final logits or the penultimate pooling layer.
Returns:
Logits.
Raises:
ValueError: If images are not the correct size.
ValueError: If neither `graph_def` nor `default_graph_def_fn` are provided.
"""
images = _validate_images(images, image_size)
if graph_def is None:
if default_graph_def_fn is None:
raise ValueError('If `graph_def` is `None`, must provide '
'`default_graph_def_fn`.')
graph_def = default_graph_def_fn()
activations = run_image_classifier(images, graph_def, input_tensor,
output_tensor)
if array_ops.rank(activations) != 2:
activations = layers.flatten(activations)
return activations
def run_image_classifier(tensor, graph_def, input_tensor,
output_tensor, scope='RunClassifier'):
"""Runs a network from a frozen graph.
Args:
tensor: An Input tensor.
graph_def: A GraphDef proto.
input_tensor: Name of input tensor in graph def.
output_tensor: Name of output tensor in graph def.
scope: Name scope for classifier.
Returns:
Classifier output. Shape depends on the classifier used, but is often
[batch, classes].
Raises:
ValueError: If `image_size` is not `None`, and `tensor` are not the correct
size.
"""
input_map = {input_tensor: tensor}
return_elements = [output_tensor]
classifier_output = importer.import_graph_def(
graph_def, input_map, return_elements, name=scope)[0]
return classifier_output
def classifier_score(images, classifier_fn, num_batches=1):
"""Classifier score for evaluating a conditional generative model.
This is based on the Inception Score, but for an arbitrary classifier.
This technique is described in detail in https://arxiv.org/abs/1606.03498. In
summary, this function calculates
exp( E[ KL(p(y|x) || p(y)) ] )
which captures how different the network's classification prediction is from
the prior distribution over classes.
Args:
images: Images to calculate the classifier score for.
classifier_fn: A function that takes images and produces logits based on a
classifier.
num_batches: Number of batches to split `generated_images` in to in order to
efficiently run them through the classifier network.
Returns:
The classifier score. A floating-point scalar.
"""
generated_images_list = array_ops.split(
images, num_or_size_splits=num_batches)
# Compute the classifier splits using the memory-efficient `map_fn`.
logits = functional_ops.map_fn(
fn=classifier_fn,
elems=array_ops.stack(generated_images_list),
parallel_iterations=1,
back_prop=False,
swap_memory=True,
name='RunClassifier')
logits = array_ops.concat(array_ops.unstack(logits), 0)
logits.shape.assert_has_rank(2)
p = nn_ops.softmax(logits)
q = math_ops.reduce_mean(p, axis=0)
kl = _kl_divergence(p, logits, q)
kl.shape.assert_has_rank(1)
log_score = math_ops.reduce_mean(kl)
return math_ops.exp(log_score)
inception_score = functools.partial(
classifier_score,
classifier_fn=functools.partial(
run_inception, output_tensor=INCEPTION_V3_OUTPUT))
def frechet_classifier_distance(real_images,
generated_images,
classifier_fn,
num_batches=1):
"""Classifier distance for evaluating a conditional generative model.
This is based on the Frechet Inception distance, but for an arbitrary
classifier.
This technique is described in detail in https://arxiv.org/abs/1706.08500.
Given two Gaussian distribution with means m and m_w and covariance matrices
C and C_w, this function calcuates
|m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
which captures how different the distributions of real images and generated
images (or more accurately, their visual features) are. Note that unlike the
Inception score, this is a true distance and utilizes information about real
world images.
Args:
real_images: Real images to use to compute Frechet Inception distance.
generated_images: Generated images to use to compute Frechet Inception
distance.
classifier_fn: A function that takes images and produces activations
based on a classifier.
num_batches: Number of batches to split images in to in order to
efficiently run them through the classifier network.
Returns:
The Frechet Inception distance. A floating-point scalar.
"""
real_images_list = array_ops.split(
real_images, num_or_size_splits=num_batches)
generated_images_list = array_ops.split(
generated_images, num_or_size_splits=num_batches)
imgs = array_ops.stack(real_images_list + generated_images_list)
# Compute the activations using the memory-efficient `map_fn`.
activations = functional_ops.map_fn(
fn=classifier_fn,
elems=imgs,
parallel_iterations=1,
back_prop=False,
swap_memory=True,
name='RunClassifier')
# Split the activations by the real and generated images.
real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)
# Ensure the activations have the right shapes.
real_a = array_ops.concat(array_ops.unstack(real_a), 0)
gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
real_a.shape.assert_has_rank(2)
gen_a.shape.assert_has_rank(2)
# Compute mean and covariance matrices of activations.
m = math_ops.reduce_mean(real_a, 0)
m_v = math_ops.reduce_mean(gen_a, 0)
dim = math_ops.to_float(array_ops.shape(m)[0])
sigma = math_ops.matmul(real_a - m, real_a - m, transpose_b=True) / dim
sigma_v = math_ops.matmul(gen_a - m, gen_a - m, transpose_b=True) / dim
# Take matrix square root of the product of covariance matrices.
sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))
# Compute the two components of FID.
trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc)
mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm.
fid = trace + mean
return fid
frechet_inception_distance = functools.partial(
frechet_classifier_distance,
classifier_fn=functools.partial(
run_inception, output_tensor=INCEPTION_V3_FINAL_POOL))

View File

@ -0,0 +1,316 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFGAN classifier_metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tarfile
import tempfile
import numpy as np
from google.protobuf import text_format
from tensorflow.contrib.gan.python.eval.python import classifier_metrics_impl as classifier_metrics
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
mock = test.mock
def _numpy_softmax(x):
e_x = np.exp(x - np.max(x, axis=1)[:, None])
return e_x / np.sum(e_x, axis=1)[:, None]
def _expected_inception_score(logits):
p = _numpy_softmax(logits)
q = np.expand_dims(np.mean(p, 0), 0)
per_example_logincscore = np.sum(p * (np.log(p) - np.log(q)), 1)
return np.exp(np.mean(per_example_logincscore))
def _approximate_matrix_sqrt(mat, eps=1e-8):
s, u, v = np.linalg.svd(mat)
si = np.where(s < eps, s, np.sqrt(s))
return np.dot(np.dot(u, np.diag(si)), v.T)
def _expected_fid(real_imgs, gen_imgs):
real_imgs = np.asarray(real_imgs)
gen_imgs = np.asarray(gen_imgs)
m = np.mean(real_imgs, axis=0)
m_v = np.mean(gen_imgs, axis=0)
dim = float(m.shape[0])
sigma = np.dot((real_imgs - m), (real_imgs - m).T) / dim
sigma_v = np.dot((gen_imgs - m), (gen_imgs - m).T) / dim
sqcc = _approximate_matrix_sqrt(np.dot(sigma, sigma_v))
mean = np.square(np.linalg.norm(m - m_v))
trace = np.trace(sigma + sigma_v - 2 * sqcc)
fid = mean + trace
return fid
# A dummy GraphDef string with the minimum number of Ops.
graphdef_string = """
node {
name: "inputs"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 299
}
dim {
size: 299
}
dim {
size: 3
}
}
}
}
}
node {
name: "InceptionV3/Logits/SpatialSqueeze"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 1001
}
}
}
}
}
node {
name: "InceptionV3/Logits/AvgPool_1a_8x8/AvgPool"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 2048
}
}
}
}
}
versions {
producer: 24
}
"""
def _get_dummy_graphdef():
dummy_graphdef = graph_pb2.GraphDef()
text_format.Merge(graphdef_string, dummy_graphdef)
return dummy_graphdef
def _run_with_mock(function, *args, **kwargs):
with mock.patch.object(
classifier_metrics,
'get_graph_def_from_url_tarball') as mock_tarball_getter:
mock_tarball_getter.return_value = _get_dummy_graphdef()
return function(*args, **kwargs)
class ClassifierMetricsTest(test.TestCase):
def test_run_inception_graph(self):
"""Test `run_inception` graph construction."""
batch_size = 7
img = array_ops.ones([batch_size, 299, 299, 3])
logits = _run_with_mock(classifier_metrics.run_inception, img)
self.assertTrue(isinstance(logits, ops.Tensor))
logits.shape.assert_is_compatible_with([batch_size, 1001])
# Check that none of the model variables are trainable.
self.assertListEqual([], variables.trainable_variables())
def test_run_inception_graph_pool_output(self):
"""Test `run_inception` graph construction with pool output."""
batch_size = 3
img = array_ops.ones([batch_size, 299, 299, 3])
pool = _run_with_mock(
classifier_metrics.run_inception, img,
output_tensor=classifier_metrics.INCEPTION_V3_FINAL_POOL)
self.assertTrue(isinstance(pool, ops.Tensor))
pool.shape.assert_is_compatible_with([batch_size, 2048])
# Check that none of the model variables are trainable.
self.assertListEqual([], variables.trainable_variables())
def test_inception_score_graph(self):
"""Test `inception_score` graph construction."""
score = _run_with_mock(classifier_metrics.inception_score,
array_ops.zeros([6, 299, 299, 3]), num_batches=3)
self.assertTrue(isinstance(score, ops.Tensor))
score.shape.assert_has_rank(0)
# Check that none of the model variables are trainable.
self.assertListEqual([], variables.trainable_variables())
def test_frechet_inception_distance_graph(self):
"""Test `frechet_inception_distance` graph construction."""
img = array_ops.ones([7, 299, 299, 3])
distance = _run_with_mock(
classifier_metrics.frechet_inception_distance, img, img)
self.assertTrue(isinstance(distance, ops.Tensor))
distance.shape.assert_has_rank(0)
# Check that none of the model variables are trainable.
self.assertListEqual([], variables.trainable_variables())
def test_run_inception_multicall(self):
"""Test that `run_inception` can be called multiple times."""
for batch_size in (7, 3, 2):
img = array_ops.ones([batch_size, 299, 299, 3])
_run_with_mock(classifier_metrics.run_inception, img)
def test_invalid_input(self):
"""Test that functions properly fail on invalid input."""
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
classifier_metrics.run_inception(array_ops.ones([7, 50, 50, 3]))
p = array_ops.zeros([8, 10])
p_logits = array_ops.zeros([8, 10])
q = array_ops.zeros([10])
with self.assertRaisesRegexp(ValueError, 'must be floating type'):
classifier_metrics._kl_divergence(
array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)
with self.assertRaisesRegexp(ValueError, 'must be floating type'):
classifier_metrics._kl_divergence(
p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)
with self.assertRaisesRegexp(ValueError, 'must be floating type'):
classifier_metrics._kl_divergence(
p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))
with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q)
with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
classifier_metrics._kl_divergence(p, array_ops.zeros([8]), q)
with self.assertRaisesRegexp(ValueError, 'must have rank 1'):
classifier_metrics._kl_divergence(p, p_logits, array_ops.zeros([10, 8]))
def test_inception_score_value(self):
"""Test that `inception_score` gives the correct value."""
logits = np.array([np.array([1, 2] * 500 + [4]),
np.array([4, 5] * 500 + [6])])
unused_image = array_ops.zeros([2, 299, 299, 3])
incscore = _run_with_mock(classifier_metrics.inception_score, unused_image)
with self.test_session(use_gpu=True) as sess:
incscore_np = sess.run(incscore, {'concat:0': logits})
self.assertAllClose(_expected_inception_score(logits), incscore_np)
def test_frechet_inception_distance_value(self):
"""Test that `frechet_inception_distance` gives the correct value."""
np.random.seed(0)
test_pool_real_a = np.random.randn(5, 2048)
test_pool_gen_a = np.random.randn(5, 2048)
unused_image = array_ops.zeros([5, 299, 299, 3])
pool_a = np.stack((test_pool_real_a, test_pool_gen_a))
fid_op = _run_with_mock(classifier_metrics.frechet_inception_distance,
unused_image, unused_image)
activations_tensor = 'RunClassifier/TensorArrayStack/TensorArrayGatherV3:0'
with self.test_session() as sess:
actual_fid = sess.run(fid_op, {activations_tensor: pool_a})
expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a)
self.assertAllClose(expected_fid, actual_fid, 0.01)
def test_preprocess_image_graph(self):
"""Test `preprocess_image` graph construction."""
incorrectly_sized_image = array_ops.zeros([520, 240, 3])
correct_image = classifier_metrics.preprocess_image(
image=incorrectly_sized_image)
_run_with_mock(classifier_metrics.run_inception,
array_ops.expand_dims(correct_image, 0))
def test_get_graph_def_from_url_tarball(self):
"""Test `get_graph_def_from_url_tarball`."""
# Write dummy binary GraphDef to tempfile.
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
tmp_file.write(_get_dummy_graphdef().SerializeToString())
relative_path = os.path.relpath(tmp_file.name)
# Create gzip tarball.
tar_dir = tempfile.mkdtemp()
tar_filename = os.path.join(tar_dir, 'tmp.tar.gz')
with tarfile.open(tar_filename, 'w:gz') as tar:
tar.add(relative_path)
with mock.patch.object(classifier_metrics, 'urllib') as mock_urllib:
mock_urllib.request.urlretrieve.return_value = tar_filename, None
graph_def = classifier_metrics.get_graph_def_from_url_tarball(
'unused_url', relative_path)
self.assertIsInstance(graph_def, graph_pb2.GraphDef)
self.assertEqual(_get_dummy_graphdef(), graph_def)
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility file for visualizing generated images."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.eval.python import eval_utils_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.eval.python.eval_utils_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = eval_utils_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,134 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility file for visualizing generated images."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
__all__ = [
"image_grid",
"image_reshaper",
]
# TODO(joelshor): Make this a special case of `image_reshaper`.
def image_grid(input_tensor, grid_shape, image_shape=(32, 32), num_channels=3):
"""Arrange a minibatch of images into a grid to form a single image.
Args:
input_tensor: Tensor. Minibatch of images to format, either 4D
([batch size, height, width, num_channels]) or flattened
([batch size, height * width * num_channels]).
grid_shape: Sequence of int. The shape of the image grid,
formatted as [grid_height, grid_width].
image_shape: Sequence of int. The shape of a single image,
formatted as [image_height, image_width].
num_channels: int. The number of channels in an image.
Returns:
Tensor representing a single image in which the input images have been
arranged into a grid.
Raises:
ValueError: The grid shape and minibatch size don't match, or the image
shape and number of channels are incompatible with the input tensor.
"""
if grid_shape[0] * grid_shape[1] != int(input_tensor.shape[0]):
raise ValueError("Grid shape %s incompatible with minibatch size %i." %
(grid_shape, int(input_tensor.shape[0])))
if len(input_tensor.shape) == 2:
num_features = image_shape[0] * image_shape[1] * num_channels
if int(input_tensor.shape[1]) != num_features:
raise ValueError("Image shape and number of channels incompatible with "
"input tensor.")
elif len(input_tensor.shape) == 4:
if (int(input_tensor.shape[1]) != image_shape[0] or
int(input_tensor.shape[2]) != image_shape[1] or
int(input_tensor.shape[3]) != num_channels):
raise ValueError("Image shape and number of channels incompatible with "
"input tensor.")
else:
raise ValueError("Unrecognized input tensor format.")
height, width = grid_shape[0] * image_shape[0], grid_shape[1] * image_shape[1]
input_tensor = array_ops.reshape(
input_tensor, tuple(grid_shape) + tuple(image_shape) + (num_channels,))
input_tensor = array_ops.transpose(input_tensor, [0, 1, 3, 2, 4])
input_tensor = array_ops.reshape(
input_tensor, [grid_shape[0], width, image_shape[0], num_channels])
input_tensor = array_ops.transpose(input_tensor, [0, 2, 1, 3])
input_tensor = array_ops.reshape(
input_tensor, [1, height, width, num_channels])
return input_tensor
def _validate_images(images):
for img in images:
img.shape.assert_has_rank(3)
img.shape.assert_is_fully_defined()
if img.shape[-1] not in (1, 3):
raise ValueError("image_reshaper only supports 1 or 3 channel images.")
# TODO(joelshor): Move the dimension logic from Python to Tensorflow.
def image_reshaper(images, num_cols=None):
"""A reshaped summary image.
Returns an image that will contain all elements in the list and will be
laid out in a nearly-square tiling pattern (e.g. 11 images will lead to a
3x4 tiled image).
Args:
images: Image data to summarize. Can be an RGB or grayscale image, a list of
such images, or a set of RGB images concatenated along the depth
dimension. The shape of each image is assumed to be [batch_size,
height, width, depth].
num_cols: (Optional) If provided, this is the number of columns in the final
output image grid. Otherwise, the number of columns is determined by
the number of images.
Returns:
A summary image matching the input with automatic tiling if needed.
Output shape is [1, height, width, channels].
"""
if isinstance(images, ops.Tensor):
images = array_ops.unstack(images)
_validate_images(images)
num_images = len(images)
num_columns = (num_cols if num_cols else
int(math.ceil(math.sqrt(num_images))))
num_rows = int(math.ceil(float(num_images) / num_columns))
rows = [images[x:x+num_columns] for x in range(0, num_images, num_columns)]
# Add empty image tiles if the last row is incomplete.
num_short = num_rows * num_columns - num_images
assert num_short >= 0 and num_short < num_columns
if num_short > 0:
rows[-1].extend([array_ops.zeros_like(images[-1])] * num_short)
# Convert each row from a list of tensors to a single tensor.
rows = [array_ops.concat(row, 1) for row in rows]
# Stack rows vertically.
img = array_ops.concat(rows, 0)
return array_ops.expand_dims(img, 0)

View File

@ -0,0 +1,48 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for eval_utils_test."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.eval.python import eval_utils_impl as eval_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class UtilsTest(test.TestCase):
def test_image_grid(self):
eval_utils.image_grid(
input_tensor=array_ops.zeros([25, 32, 32, 3]),
grid_shape=(5, 5))
# TODO(joelshor): Add more `image_reshaper` tests.
def test_image_reshaper_image_list(self):
images = eval_utils.image_reshaper(
images=array_ops.unstack(array_ops.zeros([25, 32, 32, 3])),
num_cols=2)
images.shape.assert_is_compatible_with([1, 13 * 32, 2 * 32, 3])
def test_image_reshaper_image(self):
images = eval_utils.image_reshaper(
images=array_ops.zeros([25, 32, 32, 3]),
num_cols=2)
images.shape.assert_is_compatible_with([1, 13 * 32, 2 * 32, 3])
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common TFGAN summaries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.eval.python import summaries_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.eval.python.summaries_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = summaries_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,157 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common TFGAN summaries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import util as loss_util
from tensorflow.python.summary import summary
__all__ = [
'add_gan_model_image_summaries',
'add_image_comparison_summaries',
'add_gan_model_summaries',
'add_regularization_loss_summaries',
]
def _assert_is_image(data):
data.shape.assert_has_rank(4)
data.shape[1:].assert_is_fully_defined()
def add_gan_model_image_summaries(gan_model, grid_size=10):
"""Adds image summaries for real and fake images.
Args:
gan_model: A GANModel tuple.
grid_size: The size of an image grid.
Raises:
ValueError: If real and generated data aren't images.
"""
_assert_is_image(gan_model.real_data)
_assert_is_image(gan_model.generated_data)
num_images = grid_size ** 2
real_image_shape = gan_model.real_data.shape.as_list()[1:3]
generated_image_shape = gan_model.generated_data.shape.as_list()[1:3]
real_channels = gan_model.real_data.shape.as_list()[3]
generated_channels = gan_model.generated_data.shape.as_list()[3]
summary.image(
'real_data',
eval_utils.image_grid(
gan_model.real_data[:num_images],
grid_shape=(grid_size, grid_size),
image_shape=real_image_shape,
num_channels=real_channels),
max_outputs=1)
summary.image(
'generated_data',
eval_utils.image_grid(
gan_model.generated_data[:num_images],
grid_shape=(grid_size, grid_size),
image_shape=generated_image_shape,
num_channels=generated_channels),
max_outputs=1)
add_gan_model_summaries(gan_model)
def add_image_comparison_summaries(gan_model, num_comparisons=2,
display_diffs=False):
"""Adds image summaries to compare triplets of images.
The first image is the generator input, the second is the generator output,
and the third is the real data. This style of comparison is useful for
image translation problems, where the generator input is a corrupted image,
the generator output is the reconstruction, and the real data is the target.
Args:
gan_model: A GANModel tuple.
num_comparisons: The number of image triplets to display.
display_diffs: Also display the difference between generated and target.
Raises:
ValueError: If real data, generated data, and generator inputs aren't
images.
ValueError: If the generator input, real, and generated data aren't all the
same size.
"""
_assert_is_image(gan_model.generator_inputs)
_assert_is_image(gan_model.generated_data)
_assert_is_image(gan_model.real_data)
gan_model.generated_data.shape.assert_is_compatible_with(
gan_model.generator_inputs.shape)
gan_model.real_data.shape.assert_is_compatible_with(
gan_model.generated_data.shape)
image_list = []
image_list.extend(
array_ops.unstack(gan_model.generator_inputs[:num_comparisons]))
image_list.extend(
array_ops.unstack(gan_model.generated_data[:num_comparisons]))
image_list.extend(array_ops.unstack(gan_model.real_data[:num_comparisons]))
if display_diffs:
generated_list = array_ops.unstack(
gan_model.generated_data[:num_comparisons])
real_list = array_ops.unstack(gan_model.real_data[:num_comparisons])
diffs = [
math_ops.abs(math_ops.to_float(generated) - math_ops.to_float(real)) for
generated, real in zip(generated_list, real_list)]
image_list.extend(diffs)
# Reshape image and display.
summary.image(
'image_comparison',
eval_utils.image_reshaper(image_list, num_cols=num_comparisons),
max_outputs=1)
def add_gan_model_summaries(gan_model):
"""Adds typical GANModel summaries.
Args:
gan_model: A GANModel tuple.
"""
with ops.name_scope('generator_variables'):
for var in gan_model.generator_variables:
summary.histogram(var.name, var)
with ops.name_scope('discriminator_variables'):
for var in gan_model.discriminator_variables:
summary.histogram(var.name, var)
def add_regularization_loss_summaries(gan_model):
"""Adds summaries for a regularization losses..
Args:
gan_model: A GANModel tuple.
"""
if gan_model.generator_scope:
summary.scalar(
'generator_regularization_loss',
loss_util.get_regularization_loss(gan_model.generator_scope.name))
if gan_model.discriminator_scope:
summary.scalar(
'discriminator_regularization_loss',
loss_util.get_regularization_loss(gan_model.discriminator_scope.name))

View File

@ -0,0 +1,96 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFGAN summaries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
def generator_model(inputs):
return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs
def discriminator_model(inputs, _):
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
def get_gan_model():
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
pass
with variable_scope.variable_scope('discriminator') as dis_scope:
pass
return namedtuples.GANModel(
generator_inputs=array_ops.zeros([4, 32, 32, 3]),
generated_data=array_ops.zeros([4, 32, 32, 3]),
generator_variables=[variables.Variable(0), variables.Variable(1)],
generator_scope=gen_scope,
generator_fn=generator_model,
real_data=array_ops.ones([4, 32, 32, 3]),
discriminator_real_outputs=array_ops.ones([1, 2, 3]),
discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
discriminator_variables=[variables.Variable(0)],
discriminator_scope=dis_scope,
discriminator_fn=discriminator_model)
class SummariesTest(test.TestCase):
def testAddGanModelImageSummaries(self):
summaries.add_gan_model_image_summaries(get_gan_model(), grid_size=2)
self.assertEquals(5, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
with self.test_session(use_gpu=True):
variables.global_variables_initializer().run()
summary.merge_all().eval()
def testAddGanModelSummaries(self):
summaries.add_gan_model_summaries(get_gan_model())
self.assertEquals(3, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
with self.test_session(use_gpu=True):
variables.global_variables_initializer().run()
summary.merge_all().eval()
def testAddRegularizationLossSummaries(self):
summaries.add_regularization_loss_summaries(get_gan_model())
self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
with self.test_session(use_gpu=True):
summary.merge_all().eval()
# TODO(joelshor): Add correctness test.
def testAddImageComparisonSummaries(self):
summaries.add_image_comparison_summaries(
get_gan_model(), display_diffs=True)
self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
with self.test_session(use_gpu=True):
summary.merge_all().eval()
if __name__ == '__main__':
test.main()

View File

@ -88,6 +88,7 @@ cuda_py_test(
size = "medium",
srcs = ["python/kernel_tests/image_ops_test.py"],
additional_deps = [
":distort_image_py",
":image_py",
":single_image_random_dot_stereograms_py",
"//third_party/py/numpy",
@ -99,6 +100,80 @@ cuda_py_test(
],
)
tf_custom_op_library(
name = "python/ops/_distort_image_ops.so",
srcs = [
"kernels/adjust_hsv_in_yiq_op.cc",
"ops/distort_image_ops.cc",
],
deps = [
"@protobuf_archive//:protobuf",
],
)
tf_gen_op_libs(
op_lib_names = ["distort_image_ops"],
)
tf_gen_op_wrapper_py(
name = "distort_image_ops",
deps = [":distort_image_ops_op_lib"],
)
cc_library(
name = "distort_image_ops_cc",
srcs = [
"kernels/adjust_hsv_in_yiq_op.cc",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
alwayslink = 1,
)
py_library(
name = "distort_image_py",
srcs = [
"__init__.py",
"python/ops/distort_image_ops.py",
],
data = [":python/ops/_distort_image_ops.so"],
srcs_version = "PY2AND3",
deps = [
":distort_image_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:image_ops",
"//tensorflow/python:platform",
"//tensorflow/python:random_ops",
],
)
cuda_py_test(
name = "distort_image_ops_test",
size = "medium",
srcs = ["python/kernel_tests/distort_image_ops_test.py"],
additional_deps = [
":distort_image_py",
":image_py",
":single_image_random_dot_stereograms_py",
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
"//tensorflow/core:protos_all_py",
],
)
tf_custom_op_library(
name = "python/ops/_single_image_random_dot_stereograms.so",
srcs = [

View File

@ -16,11 +16,14 @@
### API
This module provides functions for image manipulation; currently, only
This module provides functions for image manipulation; currently, chrominance
transformas (including changing saturation and hue) in YIQ space and
projective transforms (including rotation) are supported.
@@angles_to_projective_transforms
@@compose_transforms
@@adjust_yiq_hsv
@@random_yiq_hsv
@@rotate
@@transform
@@bipartite_match
@ -31,6 +34,9 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq
from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq
from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms
from tensorflow.contrib.image.python.ops.image_ops import compose_transforms
from tensorflow.contrib.image.python.ops.image_ops import rotate
@ -39,5 +45,6 @@ from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms imp
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=line-too-long
remove_undocumented(__name__)

View File

@ -0,0 +1,172 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <memory>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
class AdjustHsvInYiqOpBase : public OpKernel {
protected:
explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
struct ComputeOptions {
const Tensor* input = nullptr;
const Tensor* delta_h = nullptr;
const Tensor* scale_s = nullptr;
const Tensor* scale_v = nullptr;
Tensor* output = nullptr;
int64 channel_count = 0;
};
virtual void DoCompute(OpKernelContext* context,
const ComputeOptions& options) = 0;
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const Tensor& delta_h = context->input(1);
const Tensor& scale_s = context->input(2);
const Tensor& scale_v = context->input(3);
OP_REQUIRES(context, input.dims() >= 3,
errors::InvalidArgument("input must be at least 3-D, got shape",
input.shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()),
errors::InvalidArgument("delta_h must be scalar: ",
delta_h.shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()),
errors::InvalidArgument("scale_s must be scalar: ",
scale_s.shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()),
errors::InvalidArgument("scale_v must be scalar: ",
scale_v.shape().DebugString()));
auto channels = input.dim_size(input.dims() - 1);
OP_REQUIRES(
context, channels == 3,
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
if (input.NumElements() > 0) {
const int64 channel_count = input.NumElements() / channels;
ComputeOptions options;
options.input = &input;
options.delta_h = &delta_h;
options.scale_s = &scale_s;
options.scale_v = &scale_v;
options.output = output;
options.channel_count = channel_count;
DoCompute(context, options);
}
}
};
template <class Device>
class AdjustHsvInYiqOp;
template <>
class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
public:
explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
: AdjustHsvInYiqOpBase(context) {}
void DoCompute(OpKernelContext* context,
const ComputeOptions& options) override {
const Tensor* input = options.input;
Tensor* output = options.output;
const int64 channel_count = options.channel_count;
static const int kChannelSize = 3;
auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
const float delta_h = options.delta_h->scalar<float>()();
const float scale_s = options.scale_s->scalar<float>()();
const float scale_v = options.scale_v->scalar<float>()();
auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
const int kCostPerChannel = 10;
const DeviceBase::CpuWorkerThreads& worker_threads =
*context->device()->tensorflow_cpu_worker_threads();
Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
kCostPerChannel,
[channel_count, &input_data, &output_data, delta_h, scale_s, scale_v](
int64 start_channel, int64 end_channel) {
// Using approximate linear transfomation described in:
// https://beesbuzz.biz/code/hsv_color_transforms.php
/** Get the constants from sympy
from sympy import Matrix
from sympy.abc import u, w
# Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ
tyiq = Matrix([[0.299, 0.587, 0.114],
[0.596, -0.274, -0.322],
[0.211, -0.523, 0.312]])
# Hue rotation matrix in YIQ space.
hue_proj = Matrix(3,3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu])
m = tyiq.inv() * hue_proj * tyiq
**/
// TODO(huangyp): directly compute the projection matrix from tyiq.
static const float t[kChannelSize][kChannelSize][kChannelSize] = {
{{.299, .701, .16862179492229},
{.587, -.587, .329804745287403},
{.114, -.114, -0.498426540209694}},
{{.299, -.299, -.327963394172371},
{.587, .413, .0346106879248821},
{.114, -.114, .293352706247489}},
{{.299, -.299, 1.24646136576682},
{.587, -.587, -1.04322888291964},
{.114, .886, -.203232482847173}}};
float m[kChannelSize][kChannelSize] = {{0.}};
float su = scale_s * std::cos(delta_h);
float sw = scale_s * std::sin(delta_h);
for (int q_index = 0; q_index < kChannelSize; q_index++) {
for (int p_index = 0; p_index < kChannelSize; p_index++) {
m[q_index][p_index] = scale_v * (t[q_index][p_index][0] +
t[q_index][p_index][1] * su +
t[q_index][p_index][2] * sw);
}
}
// Applying projection matrix to input RGB vectors.
const float* p = input_data.data() + start_channel * kChannelSize;
float* q = output_data.data() + start_channel * kChannelSize;
for (int i = start_channel; i < end_channel; i++) {
for (int q_index = 0; q_index < kChannelSize; q_index++) {
q[q_index] = 0;
for (int p_index = 0; p_index < kChannelSize; p_index++) {
q[q_index] += m[q_index][p_index] * p[p_index];
}
}
p += kChannelSize;
q += kChannelSize;
}
});
}
};
REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU),
AdjustHsvInYiqOp<CPUDevice>);
// TODO(huangyp): add the GPU kernel
} // namespace tensorflow

View File

@ -0,0 +1,60 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::InferenceContext;
// --------------------------------------------------------------------------
REGISTER_OP("AdjustHsvInYiq")
.Input("images: T")
.Input("delta_h: float")
.Input("scale_s: float")
.Input("scale_v: float")
.Output("output: T")
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
})
.Doc(R"Doc(
Adjust the YIQ hue of one or more images.
`images` is a tensor of at least 3 dimensions. The last dimension is
interpretted as channels, and must be three.
We used linear transfomation described in:
beesbuzz.biz/code/hsv_color_transforms.php
The input image is considered in the RGB colorspace. Conceptually, the RGB
colors are first mapped into YIQ space, rotated around the Y channel by
delta_h in radians, multiplying the chrominance channels (I, Q) by scale_s,
multiplying all channels (Y, I, Q) by scale_v, and then remapped back to RGB
colorspace. Each operation described above is a linear transformation.
images: Images to adjust. At least 3-D.
delta_h: A float scale that represents the hue rotation amount, in radians.
Although delta_h can be any float value.
scale_s: A float scale that represents the factor to multiply the saturation by.
scale_s needs to be non-negative.
scale_v: A float scale that represents the factor to multiply the value by.
scale_v needs to be non-negative.
output: The hsv-adjusted image or images. No clipping will be done in this op.
The client can clip them using additional ops in their graph.
)Doc");
} // namespace tensorflow

View File

@ -0,0 +1,338 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for python distort_image_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.image.python.ops import distort_image_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
# TODO(huangyp): also measure the differences between AdjustHsvInYiq and
# AdjustHsv in core.
class AdjustHueInYiqTest(test_util.TensorFlowTestCase):
def _adjust_hue_in_yiq_np(self, x_np, delta_h):
"""Rotate hue in YIQ space.
Mathematically we first convert rgb color to yiq space, rotate the hue
degrees, and then convert back to rgb.
Args:
x_np: input x with last dimension = 3.
delta_h: degree of hue rotation, in radians.
Returns:
Adjusted y with the same shape as x_np.
"""
self.assertEqual(x_np.shape[-1], 3)
x_v = x_np.reshape([-1, 3])
y_v = np.ndarray(x_v.shape, dtype=x_v.dtype)
u = np.cos(delta_h)
w = np.sin(delta_h)
# Projection matrix from RGB to YIQ. Numbers from wikipedia
# https://en.wikipedia.org/wiki/YIQ
tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.322],
[0.211, -0.523, 0.312]])
y_v = np.dot(x_v, tyiq.T)
# Hue rotation matrix in YIQ space.
hue_rotation = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
y_v = np.dot(y_v, hue_rotation.T)
# Projecting back to RGB space.
y_v = np.dot(y_v, np.linalg.inv(tyiq).T)
return y_v.reshape(x_np.shape)
def _adjust_hue_in_yiq_tf(self, x_np, delta_h):
with self.test_session(use_gpu=True):
x = constant_op.constant(x_np)
y = distort_image_ops.adjust_hsv_in_yiq(x, delta_h, 1, 1)
y_tf = y.eval()
return y_tf
def test_adjust_random_hue_in_yiq(self):
x_shapes = [
[2, 2, 3],
[4, 2, 3],
[2, 4, 3],
[2, 5, 3],
[1000, 1, 3],
]
test_styles = [
'all_random',
'rg_same',
'rb_same',
'gb_same',
'rgb_same',
]
for x_shape in x_shapes:
for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255.
delta_h = (np.random.rand() * 2.0 - 1.0) * np.pi
if test_style == 'all_random':
pass
elif test_style == 'rg_same':
x_np[..., 1] = x_np[..., 0]
elif test_style == 'rb_same':
x_np[..., 2] = x_np[..., 0]
elif test_style == 'gb_same':
x_np[..., 2] = x_np[..., 1]
elif test_style == 'rgb_same':
x_np[..., 1] = x_np[..., 0]
x_np[..., 2] = x_np[..., 0]
else:
raise AssertionError('Invalid test style: %s' % (test_style))
y_np = self._adjust_hue_in_yiq_np(x_np, delta_h)
y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h)
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)
def test_invalid_shapes(self):
x_np = np.random.rand(2, 3) * 255.
delta_h = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
self._adjust_hue_in_yiq_tf(x_np, delta_h)
x_np = np.random.rand(4, 2, 4) * 255.
delta_h = np.random.rand() * 2.0 - 1.0
with self.assertRaisesOpError('input must have 3 channels but instead has '
'4 channels'):
self._adjust_hue_in_yiq_tf(x_np, delta_h)
class AdjustValueInYiqTest(test_util.TensorFlowTestCase):
def _adjust_value_in_yiq_np(self, x_np, scale):
return x_np * scale
def _adjust_value_in_yiq_tf(self, x_np, scale):
with self.test_session(use_gpu=True):
x = constant_op.constant(x_np)
y = distort_image_ops.adjust_hsv_in_yiq(x, 0, 1, scale)
y_tf = y.eval()
return y_tf
def test_adjust_random_value_in_yiq(self):
x_shapes = [
[2, 2, 3],
[4, 2, 3],
[2, 4, 3],
[2, 5, 3],
[1000, 1, 3],
]
test_styles = [
'all_random',
'rg_same',
'rb_same',
'gb_same',
'rgb_same',
]
for x_shape in x_shapes:
for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255.
scale = np.random.rand() * 2.0 - 1.0
if test_style == 'all_random':
pass
elif test_style == 'rg_same':
x_np[..., 1] = x_np[..., 0]
elif test_style == 'rb_same':
x_np[..., 2] = x_np[..., 0]
elif test_style == 'gb_same':
x_np[..., 2] = x_np[..., 1]
elif test_style == 'rgb_same':
x_np[..., 1] = x_np[..., 0]
x_np[..., 2] = x_np[..., 0]
else:
raise AssertionError('Invalid test style: %s' % (test_style))
y_np = self._adjust_value_in_yiq_np(x_np, scale)
y_tf = self._adjust_value_in_yiq_tf(x_np, scale)
self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5)
def test_invalid_shapes(self):
x_np = np.random.rand(2, 3) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
self._adjust_value_in_yiq_tf(x_np, scale)
x_np = np.random.rand(4, 2, 4) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesOpError('input must have 3 channels but instead has '
'4 channels'):
self._adjust_value_in_yiq_tf(x_np, scale)
class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase):
def _adjust_saturation_in_yiq_tf(self, x_np, scale):
with self.test_session(use_gpu=True):
x = constant_op.constant(x_np)
y = distort_image_ops.adjust_hsv_in_yiq(x, 0, scale, 1)
y_tf = y.eval()
return y_tf
def _adjust_saturation_in_yiq_np(self, x_np, scale):
"""Adjust saturation using linear interpolation."""
rgb_weights = np.array([0.299, 0.587, 0.114])
gray = np.sum(x_np * rgb_weights, axis=-1, keepdims=True)
y_v = x_np * scale + gray * (1 - scale)
return y_v
def test_adjust_random_saturation_in_yiq(self):
x_shapes = [
[2, 2, 3],
[4, 2, 3],
[2, 4, 3],
[2, 5, 3],
[1000, 1, 3],
]
test_styles = [
'all_random',
'rg_same',
'rb_same',
'gb_same',
'rgb_same',
]
with self.test_session():
for x_shape in x_shapes:
for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255.
scale = np.random.rand() * 2.0 - 1.0
if test_style == 'all_random':
pass
elif test_style == 'rg_same':
x_np[..., 1] = x_np[..., 0]
elif test_style == 'rb_same':
x_np[..., 2] = x_np[..., 0]
elif test_style == 'gb_same':
x_np[..., 2] = x_np[..., 1]
elif test_style == 'rgb_same':
x_np[..., 1] = x_np[..., 0]
x_np[..., 2] = x_np[..., 0]
else:
raise AssertionError('Invalid test style: %s' % (test_style))
y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale)
y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale)
self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5)
def test_invalid_shapes(self):
x_np = np.random.rand(2, 3) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
self._adjust_saturation_in_yiq_tf(x_np, scale)
x_np = np.random.rand(4, 2, 4) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesOpError('input must have 3 channels but instead has '
'4 channels'):
self._adjust_saturation_in_yiq_tf(x_np, scale)
class AdjustHueInYiqBenchmark(test.Benchmark):
def _benchmark_adjust_hue_in_yiq(self, device, cpu_count):
image_shape = [299, 299, 3]
warmup_rounds = 100
benchmark_rounds = 1000
config = config_pb2.ConfigProto()
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
with session.Session('', graph=ops.Graph(), config=config) as sess:
with ops.device(device):
inputs = variables.Variable(
random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
trainable=False,
dtype=dtypes.float32)
delta = constant_op.constant(0.1, dtype=dtypes.float32)
outputs = distort_image_ops.adjust_hsv_in_yiq(inputs, delta, 1, 1)
run_op = control_flow_ops.group(outputs)
sess.run(variables.global_variables_initializer())
for i in xrange(warmup_rounds + benchmark_rounds):
if i == warmup_rounds:
start = time.time()
sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + '_%s' % (cpu_count if cpu_count is not None else 'all')
print('benchmarkadjust_hue_in_yiq_299_299_3_%s step_time: %.2f us' %
(tag, step_time * 1e6))
self.report_benchmark(
name='benchmarkadjust_hue_in_yiq_299_299_3_%s' % (tag),
iters=benchmark_rounds,
wall_time=step_time)
def benchmark_adjust_hue_in_yiqCpu1(self):
self._benchmark_adjust_hue_in_yiq('/cpu:0', 1)
def benchmark_adjust_hue_in_yiqCpuAll(self):
self._benchmark_adjust_hue_in_yiq('/cpu:0', None)
class AdjustSaturationInYiqBenchmark(test.Benchmark):
def _benchmark_adjust_saturation_in_yiq(self, device, cpu_count):
image_shape = [299, 299, 3]
warmup_rounds = 100
benchmark_rounds = 1000
config = config_pb2.ConfigProto()
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
with session.Session('', graph=ops.Graph(), config=config) as sess:
with ops.device(device):
inputs = variables.Variable(
random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
trainable=False,
dtype=dtypes.float32)
scale = constant_op.constant(0.1, dtype=dtypes.float32)
outputs = distort_image_ops.adjust_hsv_in_yiq(inputs, 0, scale, 1)
run_op = control_flow_ops.group(outputs)
sess.run(variables.global_variables_initializer())
for _ in xrange(warmup_rounds):
sess.run(run_op)
start = time.time()
for _ in xrange(benchmark_rounds):
sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = '%s' % (cpu_count) if cpu_count is not None else '_all'
print('benchmarkAdjustSaturationInYiq_299_299_3_cpu%s step_time: %.2f us' %
(tag, step_time * 1e6))
self.report_benchmark(
name='benchmarkAdjustSaturationInYiq_299_299_3_cpu%s' % (tag),
iters=benchmark_rounds,
wall_time=step_time)
def benchmark_adjust_saturation_in_yiq_cpu1(self):
self._benchmark_adjust_saturation_in_yiq('/cpu:0', 1)
def benchmark_adjust_saturation_in_yiq_cpu_all(self):
self._benchmark_adjust_saturation_in_yiq('/cpu:0', None)
if __name__ == '__main__':
googletest.main()

View File

@ -0,0 +1,138 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python layer for distort_image_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import resource_loader
_distort_image_ops = loader.load_op_library(
resource_loader.get_path_to_datafile('_distort_image_ops.so'))
# pylint: disable=invalid-name
def random_hsv_in_yiq(image,
max_delta_hue=0,
lower_saturation=1,
upper_saturation=1,
lower_value=1,
upper_value=1,
seed=None):
"""Adjust hue, saturation, value of an RGB image randomly in YIQ color space.
Equivalent to `adjust_yiq_hsv()` but uses a `delta_h` randomly
picked in the interval `[-max_delta_hue, max_delta_hue]`, a `scale_saturation`
randomly picked in the interval `[lower_saturation, upper_saturation]`, and
a `scale_value` randomly picked in the interval
`[lower_saturation, upper_saturation]`.
Args:
image: RGB image or images. Size of the last dimension must be 3.
max_delta_hue: float. Maximum value for the random delta_hue. Passing 0
disables adjusting hue.
lower_saturation: float. Lower bound for the random scale_saturation.
upper_saturation: float. Upper bound for the random scale_saturation.
lower_value: float. Lower bound for the random scale_value.
upper_value: float. Upper bound for the random scale_value.
seed: An operation-specific seed. It will be used in conjunction
with the graph-level seed to determine the real seeds that will be
used in this operation. Please see the documentation of
set_random_seed for its interaction with the graph-level random seed.
Returns:
3-D float tensor of shape `[height, width, channels]`.
Raises:
ValueError: if `max_delta`, `lower_saturation`, `upper_saturation`,
`lower_value`, or `upper_Value` is invalid.
"""
if max_delta_hue < 0:
raise ValueError('max_delta must be non-negative.')
if lower_saturation < 0:
raise ValueError('lower_saturation must be non-negative.')
if lower_value < 0:
raise ValueError('lower_value must be non-negative.')
if lower_saturation > upper_saturation:
raise ValueError('lower_saturation must be < upper_saturation.')
if lower_value > upper_value:
raise ValueError('lower_value must be < upper_value.')
if max_delta_hue == 0:
delta_hue = 0
else:
delta_hue = random_ops.random_uniform(
[], -max_delta_hue, max_delta_hue, seed=seed)
if lower_saturation == upper_saturation:
scale_saturation = lower_saturation
else:
scale_saturation = random_ops.random_uniform(
[], lower_saturation, upper_saturation, seed=seed)
if lower_value == upper_value:
scale_value = lower_value
else:
scale_value = random_ops.random_uniform(
[], lower_value, upper_value, seed=seed)
return adjust_hsv_in_yiq(image, delta_hue, scale_saturation, scale_value)
def adjust_hsv_in_yiq(image,
delta_hue=0,
scale_saturation=1,
scale_value=1,
name=None):
"""Adjust hue, saturation, value of an RGB image in YIQ color space.
This is a convenience method that converts an RGB image to float
representation, converts it to YIQ, rotates the color around the Y channel by
delta_hue in radians, scales the chrominance channels (I, Q) by
scale_saturation, scales all channels (Y, I, Q) by scale_value,
converts back to RGB, and then back to the original data type.
`image` is an RGB image. The image hue is adjusted by converting the
image to YIQ, rotating around the luminance channel (Y) by
`delta_hue` in radians, multiplying the chrominance channels (I, Q) by
`scale_saturation`, and multiplying all channels (Y, I, Q) by
`scale_value`. The image is then converted back to RGB.
Args:
image: RGB image or images. Size of the last dimension must be 3.
delta_hue: float, the hue rotation amount, in radians.
scale_saturation: float, factor to multiply the saturation by.
scale_value: float, factor to multiply the value by.
name: A name for this operation (optional).
Returns:
Adjusted image(s), same shape and DType as `image`.
"""
with ops.name_scope(name, 'adjust_hsv_in_yiq', [image]) as name:
image = ops.convert_to_tensor(image, name='image')
# Remember original dtype to so we can convert back if needed
orig_dtype = image.dtype
flt_image = image_ops.convert_image_dtype(image, dtypes.float32)
rgb_altered = _distort_image_ops.adjust_hsv_in_yiq(
flt_image, delta_hue, scale_saturation, scale_value)
return image_ops.convert_image_dtype(rgb_altered, orig_dtype)

View File

@ -1678,9 +1678,14 @@ class _MultiHead(Head):
ModelFnOps that merges all heads for TRAIN.
"""
losses = []
metrics = {}
additional_train_ops = []
for m in all_model_fn_ops:
losses.append(m.loss)
if m.eval_metric_ops is not None:
for k, v in six.iteritems(m.eval_metric_ops):
# metrics["%s/%s" % (k, head_name)] = v
metrics[k] = v
additional_train_ops.append(m.train_op)
loss = self._loss_merger(losses)
@ -1689,7 +1694,8 @@ class _MultiHead(Head):
return model_fn.ModelFnOps(
mode=model_fn.ModeKeys.TRAIN,
loss=loss,
train_op=train_op)
train_op=train_op,
eval_metric_ops=metrics)
def _merge_infer(self, all_model_fn_ops):
"""Merges list of ModelFnOps for inference.

View File

@ -1703,7 +1703,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.predictions)
self.assertIsNotNone(model_fn_ops.loss)
self.assertIsNotNone(model_fn_ops.train_op)
self.assertFalse(model_fn_ops.eval_metric_ops)
self.assertTrue(model_fn_ops.eval_metric_ops)
self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess:
@ -1728,7 +1728,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.predictions)
self.assertIsNotNone(model_fn_ops.loss)
self.assertIsNotNone(model_fn_ops.train_op)
self.assertFalse(model_fn_ops.eval_metric_ops)
self.assertTrue(model_fn_ops.eval_metric_ops)
self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess:
@ -1755,7 +1755,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.predictions)
self.assertIsNotNone(model_fn_ops.loss)
self.assertIsNotNone(model_fn_ops.train_op)
self.assertFalse(model_fn_ops.eval_metric_ops)
self.assertTrue(model_fn_ops.eval_metric_ops)
self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess:

View File

@ -113,7 +113,7 @@ struct GatherTree<CPUDevice, int32> {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
int32 seq_len_b = sequence_length(batch, beam);
if (seq_len_b == 0) {
if (seq_len_b <= 0) {
continue;
}
beams(seq_len_b - 1, batch, beam) =

View File

@ -33,7 +33,10 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
if (seq_len_b <= 0) continue;
#define GET_IX(time_ix, beam_ix) \
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);

View File

@ -155,7 +155,7 @@ void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
// retrieving the sizes from the wrapped allocator removes the
// executor's reference to it, so allocator_pair.second must not
// be dereferenced again after this statement
auto sizes = allocator_pair.second->GetSizesAndUnRef();
const auto sizes = allocator_pair.second->GetSizesAndUnRef();
memory->set_allocator_name(allocator_pair.first->Name());
memory->set_total_bytes(std::get<0>(sizes));
memory->set_peak_bytes(std::get<1>(sizes));
@ -1373,7 +1373,7 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
for (const Edge* out_edge : curr_node->out_edges()) {
Node* out = out_edge->dst();
int out_id = out->id();
const int out_id = out->id();
// Add to ready queue if not visited.
bool is_visited = visited[out_id];
@ -1417,7 +1417,8 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
// Ask the device to fill in the device context map.
Device* device = impl_->params_.device;
Status fill_status = device->FillContextMap(graph, &device_context_map_);
const Status fill_status =
device->FillContextMap(graph, &device_context_map_);
if (!fill_status.ok()) {
done(fill_status);
return;
@ -1525,7 +1526,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
inline_ready.pop_front();
const Node* node = tagged_node.node;
FrameState* input_frame = tagged_node.input_frame;
int64 input_iter = tagged_node.input_iter;
const int64 input_iter = tagged_node.input_iter;
const int id = node->id();
const NodeItem& item = *gview.node(id);
@ -1637,7 +1638,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
accessed);
}
bool completed =
const bool completed =
NodeDone(s, state->item->node, ready, stats, nullptr);
delete state;
if (completed) Finish();
@ -1803,7 +1804,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
}
for (int i = 0; i < item.num_outputs; ++i) {
TensorValue val = ctx->release_output(i);
const TensorValue val = ctx->release_output(i);
if (*ctx->is_output_dead() || val.tensor == nullptr) {
// Unless it's a Switch or a Recv, the node must produce a
// tensor value at i-th output.
@ -1893,7 +1894,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
TaggedNodeSeq* ready) {
const Node* node = tagged_node.node;
FrameState* input_frame = tagged_node.input_frame;
int64 input_iter = tagged_node.input_iter;
const int64 input_iter = tagged_node.input_iter;
const bool is_dead = tagged_node.is_dead;
// Propagates outputs along out edges, and puts newly ready nodes
@ -1913,7 +1914,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
&impl_->gview_, input_iter, ready);
} else if (item->is_enter) {
bool is_constant;
Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
DCHECK(s.ok()) << s;
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
output_iter = 0;
@ -1983,7 +1984,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
// completion of this node makes its frame completed.
if (is_frame_done) {
FrameState* parent_frame = input_frame->parent_frame;
int64 parent_iter = input_frame->parent_iter;
const int64 parent_iter = input_frame->parent_iter;
DeleteFrame(input_frame, ready);
if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame.
@ -2026,7 +2027,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
}
bool completed = false;
size_t ready_size = ready.size();
const size_t ready_size = ready.size();
if (ready_size == 0 || !s.ok()) {
completed = (num_outstanding_ops_.fetch_sub(1) == 1);
} else if (ready_size > 1) {
@ -2166,7 +2167,7 @@ void ExecutorState::DumpIterationState(const FrameState* frame,
const std::vector<const Node*>* nodes = frame->nodes;
// Dump any waiting nodes that are holding on to tensors.
for (const Node* node : *nodes) {
int node_id = node->id();
const int node_id = node->id();
PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
@ -2175,14 +2176,14 @@ void ExecutorState::DumpIterationState(const FrameState* frame,
}
// Then the active nodes.
for (const Node* node : *nodes) {
int node_id = node->id();
const int node_id = node->id();
PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(node_id, iteration->input_tensors);
}
}
// Show all input tensors in use.
int total_input_tensors = frame->total_input_tensors;
const int total_input_tensors = frame->total_input_tensors;
size_t total_bytes = 0;
for (int i = 0; i < total_input_tensors; ++i) {
const Entry& input = iteration->input_tensors[i];
@ -2291,7 +2292,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
// First, propagate dead_exits (if any) to the parent frame.
FrameState* parent_frame = frame->parent_frame;
int64 parent_iter = frame->parent_iter;
const int64 parent_iter = frame->parent_iter;
if (parent_frame != nullptr) {
mutex_lock paranet_frame_lock(parent_frame->mu);
// Propagate all the dead exits to the parent frame.
@ -2300,7 +2301,8 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
for (const Edge* e : node->out_edges()) {
const Node* dst_node = e->dst();
auto dst_pending_id = impl_->gview_.node(dst_node->id())->pending_id;
const auto dst_pending_id =
impl_->gview_.node(dst_node->id())->pending_id;
// TODO(yuanbyu): We don't need this if we require the subgraph
// given to an executor not to contain a sink node.
@ -2358,7 +2360,7 @@ void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
}
if (is_frame_done) {
FrameState* parent_frame = frame->parent_frame;
int64 parent_iter = frame->parent_iter;
const int64 parent_iter = frame->parent_iter;
DeleteFrame(frame, ready);
if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame.
@ -2433,7 +2435,7 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
}
}
} else {
bool increment_dead =
const bool increment_dead =
(is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
int pending, dead;
iter_state->adjust_for_activation(dst_pending_id, increment_dead,
@ -2497,7 +2499,7 @@ void ExecutorState::FrameState::AddLoopInv(const NodeItem* item,
inv_values.push_back({item->node, entry});
// Make this value available to all iterations.
bool is_dead = !entry.has_value;
const bool is_dead = !entry.has_value;
for (int i = 0; i <= iteration_count; ++i) {
EntryVector outputs{entry};
ActivateNodes(item, is_dead, i, &outputs, ready);
@ -2522,7 +2524,7 @@ bool ExecutorState::FrameState::IsIterationDone(int64 iter) {
void ExecutorState::FrameState::IncrementIteration(const GraphView* gview,
TaggedNodeSeq* ready) {
iteration_count++;
int64 next_iter = iteration_count;
const int64 next_iter = iteration_count;
// Initialize the next iteration.
IterationState* iter_state =
@ -2567,7 +2569,7 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
Executor** executor) {
ExecutorImpl* impl = new ExecutorImpl(params, graph);
Status s = impl->Initialize();
const Status s = impl->Initialize();
if (s.ok()) {
*executor = impl;
} else {
@ -2579,7 +2581,7 @@ Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
const NodeDef& ndef, int graph_def_version,
OpKernel** kernel) {
auto device_type = DeviceType(device->attributes().device_type());
const auto device_type = DeviceType(device->attributes().device_type());
auto allocator = device->GetAllocator(AllocatorAttributes());
return CreateOpKernel(device_type, device, allocator, flib, ndef,
graph_def_version, kernel);

View File

@ -665,7 +665,9 @@ tf_kernel_library(
tf_kernel_library(
name = "matrix_band_part_op",
prefix = "matrix_band_part_op",
deps = ARRAY_DEPS,
deps = if_cuda([
":cuda_solvers",
]) + ARRAY_DEPS,
)
tf_kernel_library(
@ -1332,7 +1334,7 @@ tf_kernel_library(
"transpose_functor_gpu.cu.cc",
"transpose_functor.h",
],
visibility = ["//visibility:private"],
visibility = [":friends"],
deps = [
":ops_util",
"//tensorflow/core:framework",

View File

@ -76,18 +76,19 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void MatrixBandPart<GPUDevice, T>::Compute( \
const GPUDevice& d, Eigen::DenseIndex num_lower, \
Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
extern template struct MatrixBandPart<GPUDevice, T>;
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, bool transpose, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
} // namespace functor
template <class Scalar>
@ -131,9 +132,9 @@ class CholeskyOpGpu : public AsyncOpKernel {
// before we launch each of the Cholesky factorization kernels in paralle.
auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
functor::MatrixBandPart<GPUDevice, Scalar>::Compute(
context->eigen_device<GPUDevice>(), n, 0, input_reshaped,
output_reshaped);
functor::MatrixBandPartFunctor<GPUDevice, Scalar> fn;
fn(context, context->eigen_device<GPUDevice>(), n, 0, false /* transpose */,
input_reshaped, output_reshaped);
// Launch a Cholesky kernel for each matrix in the batch.
const int64 batch_size = input_reshaped.dimension(0);

View File

@ -1024,7 +1024,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
assert(__popc(kWidth) == 1);
int sub_warp = cub::LaneId() / kWidth;
int zeros = sub_warp * kWidth;
unsigned mask = ((1U << kWidth) - 1) << zeros;
unsigned mask = ((1UL << kWidth) - 1) << zeros;
for (int delta = kWidth / 2; delta > 0; delta /= 2) {
val += CudaShuffleXor(mask, val, delta);
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -48,18 +50,6 @@ class MatrixBandPartOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const Tensor& num_lower_in = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
errors::InvalidArgument("num_lower must be scalar, got shape ",
num_lower_in.shape().DebugString()));
const int64 num_lower = num_lower_in.scalar<int64>()();
const Tensor& num_upper_in = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in.shape().DebugString()));
const int64 num_upper = num_upper_in.scalar<int64>()();
const TensorShape& input_shape = input.shape();
// Preliminary validation of sizes.
OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
@ -67,12 +57,43 @@ class MatrixBandPartOp : public OpKernel {
"input must be at least 2-dim, received shape: ",
input.shape().DebugString()));
auto input_reshaped = input.flat_inner_dims<T, 3>();
const Tensor& num_lower_in = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
errors::InvalidArgument("num_lower must be scalar, got shape ",
num_lower_in.shape().DebugString()));
const int64 num_lower = num_lower_in.scalar<int64>()();
OP_REQUIRES(
context, num_lower <= input_reshaped.dimension(1),
errors::InvalidArgument(
"num_lower must be negative or less or equal to number of rows (",
input_reshaped.dimension(1), ") got: ", num_lower));
const Tensor& num_upper_in = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in.shape().DebugString()));
const int64 num_upper = num_upper_in.scalar<int64>()();
OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
errors::InvalidArgument("num_upper must be negative or less or "
"equal to number of columns (",
input_reshaped.dimension(2),
") got: ", num_upper));
if ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
(num_upper < 0 || num_upper == input_reshaped.dimension(2))) {
// This is a no-op.
context->set_output(0, input);
return;
}
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &output));
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, input_shape, &output));
auto output_reshaped = output->flat_inner_dims<T, 3>();
functor::MatrixBandPart<Device, T>::Compute(
context->eigen_device<Device>(), num_lower, num_upper, input_reshaped,
output_reshaped);
functor::MatrixBandPartFunctor<Device, T> fn;
fn(context, context->eigen_device<Device>(), num_lower, num_upper,
false /* transpose */, input_reshaped, output_reshaped);
}
private:
@ -98,54 +119,118 @@ TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART);
// Implementation of the functor specialization for CPU.
namespace functor {
template <typename T>
struct MatrixBandPart<CPUDevice, T> {
static void Compute(const CPUDevice& d, int64 num_lower, int64 num_upper,
typename TTypes<T, 3>::ConstTensor input,
typename TTypes<T, 3>::Tensor output) {
if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
(num_upper < 0 || num_upper >= input.dimension(2))) {
output.device(d) = input;
} else {
output.device(d) = output.constant(T());
for (int64 r = 0; r < output.dimension(0); ++r) {
for (int64 i = 0; i < output.dimension(1); ++i) {
const int64 band_start =
num_lower < 0 ? 0 : std::max(0ll, i - num_lower);
const int64 band_end =
num_upper < 0 ? output.dimension(2)
: std::min(static_cast<int64>(output.dimension(2)),
i + num_upper + 1);
if (band_start < band_end) {
const Eigen::DSizes<Eigen::DenseIndex, 3> indices(r, i, band_start);
const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
1, 1, band_end - band_start);
output.slice(indices, sizes) = input.slice(indices, sizes);
// CPU implementation of BandPartFunctor.
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Scalar>
struct MatrixBandPartFunctor<CPUDevice, Scalar> {
void operator()(OpKernelContext* context, const CPUDevice& device,
int num_lower_diags, int num_upper_diags, bool transpose,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
const int64 b = input.dimension(0);
const int64 m = input.dimension(1);
const int64 n = input.dimension(2);
auto thread_pool =
context->device()->tensorflow_cpu_worker_threads()->workers;
const int64 total_rows = b * m;
const int64 row_cost = 10 * n;
const bool in_place = input.data() == output.data();
CHECK(!(transpose && in_place));
if (!transpose) {
auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
if (!in_place) {
std::fill(output.data() + begin * n, output.data() + end * n,
Scalar());
}
const int64 batch_begin = begin / m;
const int64 batch_end = (end + m - 1) / m;
for (int64 batch = batch_begin; batch < batch_end; ++batch) {
const int64 row_begin = begin > batch * m ? begin % m : 0;
const int64 row_end = end < (batch + 1) * m ? end % m : m;
for (int64 row = row_begin; row < row_end; ++row) {
const int64 band_start =
num_lower_diags < 0
? 0
: std::min(n, std::max(0ll, row - num_lower_diags));
const int64 band_end = num_upper_diags < 0
? n
: std::min(static_cast<int64>(n),
row + num_upper_diags + 1);
if (in_place) {
if (band_start > 0) {
std::fill(&output(batch, row, 0),
&output(batch, row, band_start), Scalar());
}
if (band_end < n) {
std::fill(&output(batch, row, band_end), &output(batch, row, n),
Scalar());
}
} else {
if (band_start < band_end) {
const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
band_start);
const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
1, 1, band_end - band_start);
output.slice(indices, sizes) = input.slice(indices, sizes);
}
}
}
}
}
};
thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
} else {
output.device(device) = output.constant(Scalar());
auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
const int64 batch_begin = begin / m;
const int64 batch_end = (end + m - 1) / m;
for (int64 batch = batch_begin; batch < batch_end; ++batch) {
const int64 row_begin = begin > batch * m ? begin % m : 0;
const int64 row_end = end < (batch + 1) * m ? end % m : m;
for (int64 row = row_begin; row < row_end; ++row) {
const int64 band_start =
num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags);
const int64 band_end = num_upper_diags < 0
? n
: std::min(static_cast<int64>(n),
row + num_upper_diags + 1);
for (int64 col = band_start; col < band_end; ++col) {
output(batch, col, row) = input(batch, row, col);
}
}
}
};
thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
}
};
#define DEFINE_CPU_SPEC(T) template struct MatrixBandPartFunctor<CPUDevice, T>;
TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
#undef DEFINE_CPU_SPEC
} // namespace functor
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void MatrixBandPart<GPUDevice, T>::Compute( \
const GPUDevice& d, Eigen::DenseIndex num_lower, \
Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
extern template struct MatrixBandPart<GPUDevice, T>;
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, bool transpose, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
} // namespace functor
// Registration of the GPU implementations.

View File

@ -16,61 +16,22 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
// Generator definition for MatrixBandPartOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace generator {
template <typename T>
class MatrixBandPartGenerator {
public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE MatrixBandPartGenerator(
Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper,
typename TTypes<T, 3>::ConstTensor input)
: num_lower_(num_lower), num_upper_(num_upper), input_(input) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const Eigen::array<Eigen::DenseIndex, 3>& coords) const {
return (((num_lower_ < 0 || coords[1] - coords[2] <= num_lower_) &&
(num_upper_ < 0 || coords[2] - coords[1] <= num_upper_))
? input_(coords)
: T());
}
private:
const Eigen::DenseIndex num_lower_;
const Eigen::DenseIndex num_upper_;
typename TTypes<T, 3>::ConstTensor input_;
};
} // namespace generator
namespace functor {
template <typename Device, typename T>
struct MatrixBandPart {
EIGEN_ALWAYS_INLINE static void Compute(
const Device& d, Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper,
typename TTypes<T, 3>::ConstTensor input,
typename TTypes<T, 3>::Tensor output) {
if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
(num_upper < 0 || num_upper >= input.dimension(2))) {
output.device(d) = input;
} else {
generator::MatrixBandPartGenerator<T> generator(num_lower, num_upper,
input);
output.device(d) = output.generate(generator);
}
}
template <typename Device, typename Scalar>
struct MatrixBandPartFunctor {
void operator()(OpKernelContext* context, const Device& device,
int num_upper_diags, int num_lower_diags, bool transpose,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output);
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_

View File

@ -17,22 +17,92 @@ limitations under the License.
#define EIGEN_USE_GPU
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_SPEC(T) \
template class generator::MatrixBandPartGenerator<T>; \
template struct functor::MatrixBandPart<GPUDevice, T>;
template <bool transpose, typename Scalar>
__global__ void MatrixBandPartKernel(const int num_threads,
const int batch_size, const int m,
const int n, const int num_lower_diags,
const int num_upper_diags,
const Scalar* input_ptr,
Scalar* output_ptr) {
if (!transpose) {
CUDA_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;
const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
const int band_end =
(num_upper_diags < 0 ? n : row + num_upper_diags + 1);
if (col < band_start || col >= band_end) {
output_ptr[index] = Scalar();
} else {
output_ptr[index] = input_ptr[index];
}
}
} else {
const int matrix_size = m * n;
CUDA_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;
const int batch = index / matrix_size;
const int transpose_index = batch * matrix_size + n * col + row;
const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
const int band_end =
(num_upper_diags < 0 ? n : row + num_upper_diags + 1);
if (col < band_start || col >= band_end) {
output_ptr[transpose_index] = Scalar();
} else {
output_ptr[transpose_index] = input_ptr[index];
}
}
}
}
template <typename Scalar>
struct MatrixBandPartFunctor<GPUDevice, Scalar> {
void operator()(OpKernelContext* context, const GPUDevice& device,
int num_lower_diags, int num_upper_diags, bool transpose,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
using CudaType = typename CUDAComplexT<Scalar>::type;
const int batch_size = input.dimension(0);
const int m = input.dimension(1);
const int n = input.dimension(2);
const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data());
CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
if (transpose) {
MatrixBandPartKernel<true>
<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
num_upper_diags, input_ptr, output_ptr);
} else {
MatrixBandPartKernel<false>
<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
num_upper_diags, input_ptr, output_ptr);
}
}
};
#define DEFINE_GPU_SPEC(T) template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_bool(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC);
} // end namespace tensorflow
#undef DEFINE_GPU_SPEC
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -6898,6 +6898,41 @@ op {
}
}
}
op {
name: "DecodeRaw"
input_arg {
name: "bytes"
type: DT_STRING
}
output_arg {
name: "output"
type_attr: "out_type"
}
attr {
name: "out_type"
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_UINT16
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_INT64
}
}
}
attr {
name: "little_endian"
type: "bool"
default_value {
b: true
}
}
}
op {
name: "DecodeWav"
input_arg {

View File

@ -399,7 +399,7 @@ matrix is assumed to be zero and not accessed.
`rhs` is a tensor of shape `[..., M, K]`.
The output is a tensor of shape `[..., M, K]`. If `adjoint` is
`True` then the innermost matrices in output` satisfy matrix equations
`True` then the innermost matrices in `output` satisfy matrix equations
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
If `adjoint` is `False` then the strictly then the innermost matrices in
`output` satisfy matrix equations

Some files were not shown because too many files have changed in this diff Show More