Branch 168917534 (#13077)
* Use HLO name, rather than pointer address, for profile counter name. This removes a source of nondeterminism in IR generation. PiperOrigin-RevId: 168779489
* Eager gradient tape doesn't keep tensors alive. PiperOrigin-RevId: 168782341
* Add missing back-quote. PiperOrigin-RevId: 168785422
* Add in a comment that I forgot to add to a previous commit; NFC. PiperOrigin-RevId: 168786760
* Update ops-related pbtxt files. PiperOrigin-RevId: 168787665
* Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 168788051
* Fix typo "comptuation" (computation). PiperOrigin-RevId: 168799777
* Fix a bug in exporting a GTFlow model to the shared format with a sparse float split. PiperOrigin-RevId: 168802503
* Add signature def utility functions for inspection of input and output types and shapes. PiperOrigin-RevId: 168820997
* Apply const qualifiers whenever appropriate. PiperOrigin-RevId: 168824461
* TFE: Clearer error message when enable_eager_execution is called more than once. PiperOrigin-RevId: 168834147
* [tf.contrib.data] Add colocation constraints between Iterator and Datasets. This restores the colocation behavior that was present when Dataset objects were passed as DT_RESOURCE tensors, and avoids the (currently not supported) case where TensorFlow may attempt to split the dataset pipeline across devices. PiperOrigin-RevId: 168841061
* Optimize C++ kernels for the matrix_band_part op, which is used in various ops operating on triangular or banded matrices (a usage sketch follows this list):
  * Add benchmark for matrix_band_part.
  * Implement a simple optimized CUDA kernel instead of calling the Eigen generator.
  * Parallelize the CPU kernel for matrix_band_part.
  * Support on-the-fly transposition in the underlying functors (to be used for a future QR op in a followup).
  Benchmarks (the first column is of the form {device}_{shape}_{num_lower,num_upper}):

  Test case                      Before      After       Speedup
  cpu_(10,16,16)_(-1,-1)         5.6505e-05  6.2108e-05   -9.92%
  cpu_(10,16,16)_(-1,0)          0.00010848  0.00010908   -0.55%
  cpu_(10,16,16)_(0,-1)          0.0001055   0.00011396   -8.02%
  cpu_(10,16,16)_(2,2)           0.000108    0.00011706   -8.39%
  cpu_(10,101,101)_(-1,-1)       0.00013697  6.0558e-05  +55.79%
  cpu_(10,101,101)_(-1,0)        0.00054002  0.00017703  +67.22%
  cpu_(10,101,101)_(0,-1)        0.00051188  0.00017607  +65.60%
  cpu_(10,101,101)_(2,2)         0.00050449  0.00016904  +66.49%
  cpu_(10,256,256)_(-1,-1)       0.00032043  5.6028e-05  +82.51%
  cpu_(10,256,256)_(-1,0)        0.001335    0.0004015   +69.93%
  cpu_(10,256,256)_(0,-1)        0.0013521   0.00038862  +71.26%
  cpu_(10,256,256)_(2,2)         0.001269    0.00039959  +68.51%
  cpu_(10,1000,1000)_(-1,-1)     0.0090729   6.3419e-05  +99.30%
  cpu_(10,1000,1000)_(-1,0)      0.01712     0.0047594   +72.20%
  cpu_(10,1000,1000)_(0,-1)      0.016647    0.0046474   +72.08%
  cpu_(10,1000,1000)_(2,2)       0.012737    0.0041161   +67.68%
  cpu_(10,1024,1024)_(-1,-1)     0.0093709   5.8889e-05  +99.37%
  cpu_(10,1024,1024)_(-1,0)      0.017075    0.0051999   +69.55%
  cpu_(10,1024,1024)_(0,-1)      0.016867    0.004617    +72.63%
  cpu_(10,1024,1024)_(2,2)       0.013191    0.003759    +71.50%
  cpu_(10,2048,2048)_(-1,-1)     0.028427    6.2466e-05  +99.78%
  cpu_(10,2048,2048)_(-1,0)      0.048134    0.017642    +63.35%
  cpu_(10,2048,2048)_(0,-1)      0.048773    0.017558    +64.00%
  cpu_(10,2048,2048)_(2,2)       0.036153    0.015452    +57.26%
  cpu_(10,10,4,4)_(-1,-1)        5.8055e-05  5.8055e-05   +0.00%
  cpu_(10,10,4,4)_(-1,0)         0.00015557  0.0001564    -0.54%
  cpu_(10,10,4,4)_(0,-1)         0.00015855  0.00015199   +4.14%
  cpu_(10,10,4,4)_(2,2)          0.00016379  0.00018096  -10.48%
  cpu_(10,10,10,10)_(-1,-1)      6.0558e-05  6.0558e-05   +0.00%
  cpu_(10,10,10,10)_(-1,0)       0.000368    0.00038695   -5.15%
  cpu_(10,10,10,10)_(0,-1)       0.00036263  0.00038612   -6.48%
  cpu_(10,10,10,10)_(2,2)        0.00038648  0.00042963  -11.17%
  cpu_(10,10,16,16)_(-1,-1)      6.9022e-05  5.7578e-05  +16.58%
  cpu_(10,10,16,16)_(-1,0)       0.0005815   0.0001874   +67.77%
  cpu_(10,10,16,16)_(0,-1)       0.00059354  0.0001924   +67.58%
  cpu_(10,10,16,16)_(2,2)        0.00062239  0.00019097  +69.32%
  cpu_(10,10,101,101)_(-1,-1)    0.00014806  6.2823e-05  +57.57%
  cpu_(10,10,101,101)_(-1,0)     0.0039785   0.00078249  +80.33%
  cpu_(10,10,101,101)_(0,-1)     0.0040585   0.00076556  +81.14%
  cpu_(10,10,101,101)_(2,2)      0.0039514   0.00077307  +80.44%
  cpu_(10,10,256,256)_(-1,-1)    0.0026824   6.0558e-05  +97.74%
  cpu_(10,10,256,256)_(-1,0)     0.017269    0.0031619   +81.69%
  cpu_(10,10,256,256)_(0,-1)     0.020287    0.0030774   +84.83%
  cpu_(10,10,256,256)_(2,2)      0.011919    0.0026599   +77.68%
  cpu_(10,10,1000,1000)_(-1,-1)  0.065783    5.6982e-05  +99.91%
  cpu_(10,10,1000,1000)_(-1,0)   0.1361      0.054533    +59.93%
  cpu_(10,10,1000,1000)_(0,-1)   0.1397      0.053405    +61.77%
  cpu_(10,10,1000,1000)_(2,2)    0.10173     0.048561    +52.26%
  cpu_(10,10,1024,1024)_(-1,-1)  0.066231    7.5579e-05  +99.89%
  cpu_(10,10,1024,1024)_(-1,0)   0.13615     0.059931    +55.98%
  cpu_(10,10,1024,1024)_(0,-1)   0.13745     0.064931    +52.76%
  cpu_(10,10,1024,1024)_(2,2)    0.10493     0.054258    +48.29%
  cpu_(10,10,2048,2048)_(-1,-1)  0.23487     6.6042e-05  +99.97%
  cpu_(10,10,2048,2048)_(-1,0)   0.41014     0.24283     +40.79%
  cpu_(10,10,2048,2048)_(0,-1)   0.43621     0.26393     +39.49%
  cpu_(10,10,2048,2048)_(2,2)    0.29919     0.22302     +25.46%
  gpu_(10,16,16)_(-1,-1)         0.00010753  0.00010753   +0.00%
  gpu_(10,16,16)_(-1,0)          0.00011253  0.00012445  -10.59%
  gpu_(10,16,16)_(0,-1)          0.00012493  0.00013399   -7.25%
  gpu_(10,16,16)_(2,2)           0.000108    0.00011754   -8.83%
  gpu_(10,101,101)_(-1,-1)       0.00011849  8.7976e-05  +25.75%
  gpu_(10,101,101)_(-1,0)        0.00012743  0.00012243   +3.93%
  gpu_(10,101,101)_(0,-1)        0.00012958  0.00012362   +4.60%
  gpu_(10,101,101)_(2,2)         0.00011504  0.00011504   +0.00%
  gpu_(10,256,256)_(-1,-1)       0.00013447  9.7513e-05  +27.48%
  gpu_(10,256,256)_(-1,0)        0.00018752  0.00014746  +21.36%
  gpu_(10,256,256)_(0,-1)        0.00017798  0.00016904   +5.02%
  gpu_(10,256,256)_(2,2)         0.0001514   0.00013697   +9.53%
  gpu_(10,1000,1000)_(-1,-1)     0.0005095   9.8586e-05  +80.65%
  gpu_(10,1000,1000)_(-1,0)      0.00088501  0.00056589  +36.06%
  gpu_(10,1000,1000)_(0,-1)      0.00090456  0.00055242  +38.93%
  gpu_(10,1000,1000)_(2,2)       0.00080955  0.00049639  +38.68%
  gpu_(10,1024,1024)_(-1,-1)     0.00050902  9.7036e-05  +80.94%
  gpu_(10,1024,1024)_(-1,0)      0.00098789  0.00058246  +41.04%
  gpu_(10,1024,1024)_(0,-1)      0.001       0.00059545  +40.46%
  gpu_(10,1024,1024)_(2,2)       0.00082254  0.00049961  +39.26%
  gpu_(10,2048,2048)_(-1,-1)     0.001495    9.8944e-05  +93.38%
  gpu_(10,2048,2048)_(-1,0)      0.003535    0.0017736   +49.83%
  gpu_(10,2048,2048)_(0,-1)      0.0034965   0.0017921   +48.75%
  gpu_(10,2048,2048)_(2,2)       0.0027704   0.0015399   +44.41%
  gpu_(10,10,4,4)_(-1,-1)        0.00011086  9.1076e-05  +17.85%
  gpu_(10,10,4,4)_(-1,0)         0.0001235   0.00013411   -8.59%
  gpu_(10,10,4,4)_(0,-1)         0.00011849  0.0001204    -1.61%
  gpu_(10,10,4,4)_(2,2)          0.00010896  0.00013256  -21.66%
  gpu_(10,10,10,10)_(-1,-1)      0.00010657  9.5844e-05  +10.07%
  gpu_(10,10,10,10)_(-1,0)       0.00011754  0.00013602  -15.72%
  gpu_(10,10,10,10)_(0,-1)       0.00011909  0.00012004   -0.80%
  gpu_(10,10,10,10)_(2,2)        0.00013196  0.00011349  +14.00%
  gpu_(10,10,16,16)_(-1,-1)      0.00012898  0.00010705  +17.01%
  gpu_(10,10,16,16)_(-1,0)       0.00014353  0.00012338  +14.04%
  gpu_(10,10,16,16)_(0,-1)       0.00011599  0.00012493   -7.71%
  gpu_(10,10,16,16)_(2,2)        0.00011539  0.00011349   +1.65%
  gpu_(10,10,101,101)_(-1,-1)    0.00014699  0.00010252  +30.25%
  gpu_(10,10,101,101)_(-1,0)     0.0002141   0.00015497  +27.62%
  gpu_(10,10,101,101)_(0,-1)     0.0002017   0.00015843  +21.45%
  gpu_(10,10,101,101)_(2,2)      0.00018394  0.00015402  +16.27%
  gpu_(10,10,256,256)_(-1,-1)    0.00032747  9.0003e-05  +72.52%
  gpu_(10,10,256,256)_(-1,0)     0.00074494  0.00040746  +45.30%
  gpu_(10,10,256,256)_(0,-1)     0.00072503  0.00042391  +41.53%
  gpu_(10,10,256,256)_(2,2)      0.00061846  0.00038004  +38.55%
  gpu_(10,10,1000,1000)_(-1,-1)  0.0032645   0.00010896  +96.66%
  gpu_(10,10,1000,1000)_(-1,0)   0.007543    0.0038971   +48.34%
  gpu_(10,10,1000,1000)_(0,-1)   0.006058    0.0039405   +34.95%
  gpu_(10,10,1000,1000)_(2,2)    0.005198    0.003448    +33.67%
  gpu_(10,10,1024,1024)_(-1,-1)  0.0034155   9.1434e-05  +97.32%
  gpu_(10,10,1024,1024)_(-1,0)   0.007099    0.004158    +41.43%
  gpu_(10,10,1024,1024)_(0,-1)   0.006843    0.003849    +43.75%
  gpu_(10,10,1024,1024)_(2,2)    0.005506    0.0031376   +43.02%
  gpu_(10,10,2048,2048)_(-1,-1)  0.013119    0.00010097  +99.23%
  gpu_(10,10,2048,2048)_(-1,0)   0.028533    0.015175    +46.81%
  gpu_(10,10,2048,2048)_(0,-1)   0.028458    0.014926    +47.55%
  gpu_(10,10,2048,2048)_(2,2)    0.022175    0.011797    +46.80%

  PiperOrigin-RevId: 168849471
* dataset_ops.read_batch_features() now discards keys for a keyed Dataset, and ignores an unnecessary repeat() when num_repeat == 1. PiperOrigin-RevId: 168855155
* Migrate TFGAN eval to opensource. PiperOrigin-RevId: 168855880
* [XLA] Remove superfluous locking from xla::ComputationBuilder. The class is thread-compatible, not thread-safe: it is illegal to call non-const methods of the class concurrently, so the mutex is pointless. Also mark a couple of accessors const. PiperOrigin-RevId: 168857132
* Add ConvertGraphDefToXla to convert from GraphDef to xla::Computation. The main logic is simply refactored from tfcompile, with some minor cleanups along the way. PiperOrigin-RevId: 168857174
* Bugfix to tf.contrib.seq2seq beam_search_ops: GPU edge case of seq_len == 0. PiperOrigin-RevId: 168862288
* [tf.contrib.data] Add the `batch_and_drop_remainder` transformation. This transformation, which is designed for use with `Dataset.apply()`, acts like the default behavior of `tf.train.batch()`, which will truncate a finite input source if its number of elements is not an exact multiple of the batch size. A benefit of using this transformation is that it gives a statically known shape to the output elements, because they are all exactly `batch_size` in the 0th dimension. PiperOrigin-RevId: 168863148
* Minor renaming from tfcompile.Config to tf2xla.Config in comments. PiperOrigin-RevId: 168863860
* Certain ops don't need eager gradients to keep their inputs / outputs alive. PiperOrigin-RevId: 168864350
* [XLA] Add S64 while loop test. PiperOrigin-RevId: 168865653
* tfdbg: fix a bug in list_inputs and list_outputs wherein a tensor name like "x:1" failed to be processed because it was not converted to the node name ("x" in this example) first. Also simplify analyzer_cli_test.py a little through a new helper function. PiperOrigin-RevId: 168867948
* Adds multi_label_head in tf.contrib.estimator. PiperOrigin-RevId: 168873313
* Script that generates __init__.py files based on tf_api_names annotations. PiperOrigin-RevId: 168878737
* Fixing the build command. PiperOrigin-RevId: 168881605
* Make sure all checked threads are joined before they are terminated. PiperOrigin-RevId: 168884294
* Output metrics in train mode for multihead. This is to be consistent with other heads, which output the metric tensors in train mode. Outputting the metric tensors allows us, for example, to plot the metrics on the training set (and compare them to the metrics on the eval set). PiperOrigin-RevId: 168884726
* Automated g4 rollback of changelist 168458634. PiperOrigin-RevId: 168887778
* Adds DNNEstimator to tf.contrib.estimator. PiperOrigin-RevId: 168887825
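For reference, a minimal usage sketch of the op benchmarked above (illustrative only, not part of the original log; assumes the 1.x-era Python API, where the op is exposed as tf.matrix_band_part):

```python
import tensorflow as tf

# Keep num_lower subdiagonals and num_upper superdiagonals of each
# innermost matrix; -1 keeps that whole triangle.
x = tf.reshape(tf.range(16, dtype=tf.float32), [4, 4])
band = tf.matrix_band_part(x, 2, 2)    # the (2,2) cases in the table
lower = tf.matrix_band_part(x, -1, 0)  # lower-triangular, the (-1,0) cases
upper = tf.matrix_band_part(x, 0, -1)  # upper-triangular, the (0,-1) cases

with tf.Session() as sess:
    print(sess.run(lower))
```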
* [tf.contrib.data] Expose `tf.contrib.data.batch_and_drop_remainder()` (usage sketch after this list). PiperOrigin-RevId: 168888592
* Disabling timeout test in opensource build. PiperOrigin-RevId: 168890483
* Add ops that perform color transforms (including changing value, saturation, and hue) in YIQ space. PiperOrigin-RevId: 168897736
* Update the minimum required epsilon for batch norm. PiperOrigin-RevId: 168897907
* Adding support for capture-by-value. PiperOrigin-RevId: 168903482
* Disabling failing tsan test. PiperOrigin-RevId: 168903876
* Disable asan for test timeout. PiperOrigin-RevId: 168903999
* Internal change. PiperOrigin-RevId: 168910187
* Fix broken test: tensorflow/contrib/eager/python:datasets_test. PiperOrigin-RevId: 168914742
* [XLA:CPU] Implement map fusion. PiperOrigin-RevId: 168915358
* Merge changes from github. END_PUBLIC I also integrated #13073 by hand to make TAP happy.
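An illustrative sketch of the newly exposed transformation (not part of the original log; assumes the contrib-era Dataset API):

```python
import tensorflow as tf

# batch_and_drop_remainder truncates a finite dataset to a multiple of
# the batch size, so every emitted batch has a static 0th dimension.
dataset = tf.contrib.data.Dataset.range(10).apply(
    tf.contrib.data.batch_and_drop_remainder(4))
batch = dataset.make_one_shot_iterator().get_next()
print(batch.shape)  # (4,): statically known, unlike a plain batch(4)

with tf.Session() as sess:
    print(sess.run(batch))  # [0 1 2 3]; the remainder [8, 9] is dropped
```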
--- Commit 92362d0f0 authored by Skye Wanderman-Milne <skyewm@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add WhileContext class and add plumbing for creating them. This change introduces WhileContext, which stores information about a while loop and will be used in future changes to generate while loop gradient graphs. Exit nodes in a while loop now have a pointer to their associated WhileContext. This will be used to retrieve the context for a given loop. This change adds an optional parameter to BuildWhileLoop() to create a WhileContext for the while loop (currently this is always true, but gradients will generate while loops without associated contexts). This change also adds an as-yet-unused option to BuildWhileLoop() to return the predicate output. PiperOrigin-RevId: 168562303
--- Commit a4f6e7c1a authored by RJ Ryan <rjryan@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add mel-scale conversion matrix support to tf.contrib.signal. PiperOrigin-RevId: 168560255
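A hedged sketch of the mel-conversion support this adds (illustrative; assumes the function is exposed as tf.contrib.signal.linear_to_mel_weight_matrix with the parameter names shown):

```python
import tensorflow as tf

# Build a [num_spectrogram_bins, num_mel_bins] conversion matrix and
# warp a linear-frequency magnitude spectrogram onto the mel scale.
num_spectrogram_bins = 513  # e.g. from a 1024-point STFT
mel_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
    num_mel_bins=64,
    num_spectrogram_bins=num_spectrogram_bins,
    sample_rate=16000,
    lower_edge_hertz=80.0,
    upper_edge_hertz=7600.0)
spectrogram = tf.abs(tf.random_normal([100, num_spectrogram_bins]))
mel_spectrogram = tf.matmul(spectrogram, mel_matrix)
```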
--- Commit b00b6d23c authored by Henry Tan <henrytan@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Fix a segmentation fault caused by an invalid log directory in InternalFlush(). PiperOrigin-RevId: 168557063
--- Commit 2bc7a155a authored by Yong Tang <yong.tang.github@outlook.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>: Add uint16 support for tf.decode_raw (#12719). This fix addresses the request raised in #10124, where uint16 support for tf.decode_raw is needed. tf.decode_raw already supports half, float32, float64, int8, int16, int32, int64, and uint8; uint16 was not supported. This fix adds uint16 support for tf.decode_raw and fixes #10124. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix a test failure caused by uint16 support of decode_raw and add unit tests. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
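A small sketch of the new dtype in use (illustrative only; tf.decode_raw reinterprets a string tensor's bytes as fixed-width numbers):

```python
import numpy as np
import tensorflow as tf

# Four little-endian uint16 values packed into a byte string.
raw_bytes = np.arange(4, dtype=np.uint16).tobytes()
decoded = tf.decode_raw(tf.constant(raw_bytes), out_type=tf.uint16)

with tf.Session() as sess:
    print(sess.run(decoded))  # [0 1 2 3]
```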
--- Commit 009285c09 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Remove benchmark for TensorShapeOld. PiperOrigin-RevId: 168551108
--- Commit dc1eda8a6 authored by Peter Hawkins <phawkins@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Fix CHECK-failure crash if a non-tuple was passed to GetTupleElement. PiperOrigin-RevId: 168550703
--- Commit 010922ed9 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 168549989
--- Commit c8a6131e9 authored by Mark Daoust <markdaoust@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Make `tf.sets` examples executable. Fixes #12969. PiperOrigin-RevId: 168549712
--- Commit bece65c6f authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Use a map instead of a vector of Children() in the BeamEntry. The assumption is that since the entries are sparse (they are all populated, but most are never Active()), using the map will save memory and make iterating over the Children() more efficient. PiperOrigin-RevId: 168548814
--- Commit 0d5ab82ce authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Update ops-related pbtxt files. PiperOrigin-RevId: 168548642
--- Commit 3331c574b authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Implementing gradients for tf.image.resize_bicubic. PiperOrigin-RevId: 168547412
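With the gradient registered, resize_bicubic can now appear inside a differentiated graph; a minimal sketch (illustrative, not from the commit):

```python
import tensorflow as tf

images = tf.get_variable("images", [1, 8, 8, 3])
resized = tf.image.resize_bicubic(images, size=[16, 16])
loss = tf.reduce_sum(tf.square(resized))
# Defined now that the bicubic-resize gradient is implemented.
grads = tf.gradients(loss, [images])
```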
--- Commit 4982ef0fa authored by Martin Wicke <wicke@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add the ability to warn only once if deprecated functionality is used, and make that the default. PiperOrigin-RevId: 168545655
--- Commit 99423416a authored by Peter Hawkins <phawkins@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Make shape inference error messages for the While HLO more readable. Build the error lazily. PiperOrigin-RevId: 168531083
--- Commit d10374e45 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Discard some unnecessary logging commands. PiperOrigin-RevId: 168500721
--- Commit 83cbabb85 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Fix wrong format of logging message. PiperOrigin-RevId: 168497373
--- Commit eec4f1b3a authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 168494944
--- Commit 69301f352 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Update ops-related pbtxt files. PiperOrigin-RevId: 168494220
--- Commit 9d56f419c authored by Mingxing Tan <tanmingxing@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add crop_and_decode_jpeg_op, which combines the crop and decode for better performance. PiperOrigin-RevId: 168493125
--- Commit 48ddf64d0 authored by Chris Leary <leary@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Make the large params test only run in opt builds. PiperOrigin-RevId: 168491913
--- Commit 11d3ac29d authored by Chris Leary <leary@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Add tests for large numbers of parameter / return values and while loops. PiperOrigin-RevId: 168487225
--- Commit 3cd6bdef5 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Added test cases on R4 slice. PiperOrigin-RevId: 168482049
--- Commit 46a81b5c3 authored by Jacques Pienaar <jpienaar@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add cast S64 to F32 test. PiperOrigin-RevId: 168473650
--- Commit 59bdf598d authored by Derek Murray <mrry@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add an automatically generated "tensorflow.python.platform.build_info" script. The motivation for this script is to provide better tools for diagnosing load-time errors (such as the ones that plague the Windows build due to DLL issues). Note that the script is intended to be self-contained, so that it is possible to import it without loading the entire TensorFlow runtime. This generated script currently contains a single symbol, `is_cuda_build`, which records whether the build has GPU support or not. PiperOrigin-RevId: 168471034
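Per the entry, the generated module exposes a single symbol and imports without pulling in the TensorFlow runtime:

```python
# Importable even when the main runtime DLL fails to load, which is the
# point: it helps diagnose load-time errors on Windows.
from tensorflow.python.platform import build_info

print(build_info.is_cuda_build)  # True for GPU builds, False otherwise
```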
--- Commit c3b86347f authored by Olivia Nordquist <nolivia@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Re-enable tests that are passing. PiperOrigin-RevId: 168466361
--- Commit c728665ec authored by Henry Tan <henrytan@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add const qualifiers whenever appropriate. PiperOrigin-RevId: 168465926
--- Commit bf96fcd13 authored by Alexandre Passos <apassos@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Use the scalar cache in MeanGrad. PiperOrigin-RevId: 168462267
--- Commit 1cada9ea2 authored by Olivia Nordquist <nolivia@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Re-enable a test that passed after 100 runs without timing out. PiperOrigin-RevId: 168458634
--- Commit 00c865566 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Generate an error (instead of a segfault) when trying to copy a string tensor to the GPU in the EagerTensor constructor. PiperOrigin-RevId: 168457320
--- Commit 655f26fc7 authored by Alexandre Passos <apassos@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Resurrects autograd-free eager gradients. PiperOrigin-RevId: 168448557
--- Commit 8f37f3002 authored by Peter Hawkins <phawkins@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [TF:XLA] Cleanups to handling of arguments during XLA compilation:
  * Combine resource kinds in XlaCompiler::Argument::Kind, and use a separate XlaResource::Kind field to distinguish different kinds of resource.
  * Merge XlaContext::HandleOrConstant and XlaExpression, which were almost identical.
  * Remove XlaContext::Argument; instead, build XlaExpressions directly from XlaCompiler and add them to the XlaContext.
  PiperOrigin-RevId: 168439341
--- Commit 7f5346a80 authored by Gunhan Gulsoy <gunan@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Reduce cmake log mess:
  * Echo off for the .bat scripts.
  * TF cmake: disable warnings in some of the patched projects (gif, jpeg, lmdb).
  PiperOrigin-RevId: 168432070
--- Commit 2ad85aa4d authored by Mark Heffernan <meheff@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Use xla/tests:xla_internal_test_main for all tests under tf/compiler/xla, and remove any main() definitions in tests. This enables use of flags in all tests. PiperOrigin-RevId: 168424796
--- Commit cd377811d authored by Henry Tan <henrytan@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Comment and error message consistency cleanup. PiperOrigin-RevId: 168422582
--- Commit 7c19b82af authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Update tf.sparse_reset_shape so that when shrinking the shape of an empty sparse tensor, the result has a shape of all zeros. PiperOrigin-RevId: 168419639
--- Commit fcacb40d4 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: FirstReadyManager for scheduling nodes in VirtualScheduler. The current FIFOManager may yield inefficient scheduling; a _Recv pushed to the FIFO blocks other nodes that could run before the _Recv, due to the node order in the FIFO. FirstReadyManager picks the node with the earliest time_ready in the queue, avoiding this problem. Also, fixed VirtualPlacer to properly set the device when a Node's device name does not include a job name, and to set GPU:0 as the default device. PiperOrigin-RevId: 168418455
--- Commit 7e47624f5 authored by Asim Shankar <ashankar@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: eager: Initial support for iteration over tf.contrib.data.Dataset objects. TODO: support function-valued operation attributes in eager (required for MapDataset, FilterDataset, etc., which encode the per-element computation in a TensorFlow function). PiperOrigin-RevId: 168418250
--- Commit b0a397fce authored by Asim Shankar <ashankar@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: eager: Remove unnecessary TFE_Context argument to TFE_OpSetDevice. PiperOrigin-RevId: 168417999
--- Commit 86211d554 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Graph transform to flatten atrous (dilated) convolutions (i.e., a sequence of SpaceToBatchND-Conv-BatchToSpaceND ops) into a regular Conv op with upsampled filters. PiperOrigin-RevId: 168414124
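The pattern the transform flattens is the expansion used by atrous convolution; a sketch of a graph that contains it (illustrative; tf.nn.atrous_conv2d expands to exactly this SpaceToBatchND-Conv2D-BatchToSpaceND sequence during graph construction):

```python
import tensorflow as tf

value = tf.placeholder(tf.float32, [1, 64, 64, 16])
filters = tf.get_variable("w", [3, 3, 16, 32])
# Expands to SpaceToBatchND -> Conv2D -> BatchToSpaceND, which the
# transform can now rewrite into one Conv2D with upsampled filters.
out = tf.nn.atrous_conv2d(value, filters, rate=2, padding="SAME")
```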
--- Commit 3438981ca authored by David G. Andersen <dga@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Apply exported symbol filtering to the C++ API analogously to what is filtered for the C API. Fixes a bug reported in comments on #1924. PiperOrigin-RevId: 168413719
--- Commit 7e023d865 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA:CPU] Remove code from parallel CPU backend outlining that was causing unnecessary copies to be inserted, and which is no longer necessary since we added co-located buffer support for kCall:
  * The all-bitcast copy is no longer necessary, as CopyInsertion will insert copies at the root of the computation for a parameter which is live-out.
  * The copy-if-root-does-not-define-buffer is no longer necessary, because colocated assignment looks at the points-to set of the root instruction.
  PiperOrigin-RevId: 168412076
--- Commit 5da4df92c authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Simplify some code in grappler_item_builder.cc; no change in logic. PiperOrigin-RevId: 168409110
--- Commit 82ec6241a authored by drpngx <drpngx@users.noreply.github.com>, committed by GitHub <noreply@github.com>: Add six and numpy imports.
--- Commit 9c4ce2452 authored by Mark Heffernan <meheff@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add flag parsing to more tests in xla/service, specifically those which build HLO graphs. This enables, for example, dumping of the graphs with --xla_generate_hlo_graph. Also remove some superfluous tensorflow test_main dependencies. PiperOrigin-RevId: 168406746
--- Commit d4efa695c authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Relax the feed_nodes collection check, which triggers a false positive in some modes where the feed node collection is auto-generated. Keep it as a warning to help correct user-provided feed node lists. PiperOrigin-RevId: 168396408
--- Commit cbc46a856 authored by Changming Sun <chasun@microsoft.com>, committed by gunan <gunan@google.com>: Add a missing template explicit instantiation of SetZeroFunctor (#12791).
--- Commit 7bb08f5bf authored by Kevin Slagle <kjslag@gmail.com>, committed by drpngx <drpngx@users.noreply.github.com>: Fix ExponentialMovingAverage documentation so that ExponentialMovingAverage.apply is evaluated within control_dependencies (#12987).
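The corrected documentation pattern, for reference (a sketch; `train_op` stands in for a real optimizer step):

```python
import tensorflow as tf

var = tf.Variable(0.0)
train_op = var.assign_add(1.0)  # stand-in for an optimizer step
ema = tf.train.ExponentialMovingAverage(decay=0.99)

# apply() is created under control_dependencies on the training op, so
# evaluating it runs the train step first and then the EMA update.
with tf.control_dependencies([train_op]):
    train_and_update = ema.apply([var])
```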
--- Commit e6b011763 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Extend the C++ gradient_checker to complex types. PiperOrigin-RevId: 168392949
--- Commit 4086219a4 authored by Lyndon White <oxinabox@ucc.asn.au>, committed by drpngx <drpngx@users.noreply.github.com>: Correct a minor typo in the substr docs example (#12991).
--- Commit f63aa7f49 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Migrate core TFGAN functions to opensource. PiperOrigin-RevId: 168391923
--- Commit bc6b60f1b authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Fix a tuple_losses bug caused by a Python bug. PiperOrigin-RevId: 168386341
--- Commit 7a8c63da3 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Migrate `leaky_relu` to `nn_ops.py`. Will be used for TFGAN. PiperOrigin-RevId: 168386268
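After the migration the activation sits alongside the other core ops; a one-line sketch (assuming it is exported as tf.nn.leaky_relu, as in later releases):

```python
import tensorflow as tf

x = tf.constant([-2.0, 0.0, 3.0])
y = tf.nn.leaky_relu(x, alpha=0.2)  # -> [-0.4, 0.0, 3.0]
```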
--- Commit f7ba16fdf authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Do not export from eval on train data steps. PiperOrigin-RevId: 168374021
--- Commit 9b9e54b34 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Adding NCCL sum op, registering all_sum gradient. Streamlining the nccl test. PiperOrigin-RevId: 168347428
--- Commit bc300318e authored by Gunhan Gulsoy <gunan@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Update the gemmlowp hash, as the commit history seems to have changed in the repository. PiperOrigin-RevId: 168343607
--- Commit 1e96d54d9 authored by gunan <gunan@google.com>, committed by GitHub <noreply@github.com>: Also accept non-k8 CPU types in build pip package (#12975):
  * Also accept non-k8 CPU types in build pip package. Fixes #12735.
  * Make the script work with `set -e`.
--- Commit c0a4c7ffc authored by Chris Leary <leary@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Fix a bug in ShapeUtil::ShapeIs that would lead to type inference errors. PiperOrigin-RevId: 168323589
--- Commit 4af9be964 authored by Amy <amy@infosleuth.net>, committed by drpngx <drpngx@users.noreply.github.com>: Support passing in a source URL to the mnist read_data_sets function, to make it easier to use 'fashion mnist' etc. (#12983).
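An illustrative sketch of the new parameter (the kwarg name `source_url` and the Fashion-MNIST mirror URL are assumptions based on the linked PR):

```python
from tensorflow.examples.tutorials.mnist import input_data

# Point the standard MNIST reader at a different mirror, e.g. Fashion-MNIST.
fashion = input_data.read_data_sets(
    "data/fashion",
    source_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/")
```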
--- Commit 9f848734f authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Tweak layer a bit to be eager-friendly. PiperOrigin-RevId: 168312865
--- Commit 60f15462b authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Change conv_input_scale and side_input_scale from attributes to inputs in the fused_conv2d_bias_activation op, for improved flexibility. PiperOrigin-RevId: 168311988
--- Commit 4b4e10f9c authored by Jianwei Xie <xiejw@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Adds dict support for eval metrics. PiperOrigin-RevId: 168310444
--- Commit ab7f22de6 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Move FusedConvBiasActivationShape out of common_shape_fns.cc to a lambda inside the op. PiperOrigin-RevId: 168300911
--- Commit 3a98035fa authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Augment metadata output with source-line info, as before. PiperOrigin-RevId: 168292527
--- Commit 349188152 authored by Yao Zhang <yaozhang@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Enable fused batch norm, which is 15-20% faster for training and inference. PiperOrigin-RevId: 168288154
--- Commit 08587d45b authored by Yuefeng Zhou <yuefengz@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Added back persistent memory tracking in the queue op. The new tracking logic avoids the crash in the previous implementation: the queue_ passed to CreateTypedQueue may be unreffed if the resource was already created by another resource op that shares the same resource name and type. PiperOrigin-RevId: 168284509
--- Commit 733063d55 authored by Amit Patankar <amitpatankar@google.com>, committed by Amit Patankar <amitpatankar@google.com>: Fixing awkward wording.
--- Commit c7ad6bfef authored by Amit Patankar <amitpatankar@google.com>, committed by Amit Patankar <amitpatankar@google.com>: Removing accidental hash.
--- Commit 53dbc761a authored by Amit Patankar <amitpatankar@google.com>, committed by Amit Patankar <amitpatankar@google.com>: Adding Windows self-check script to docs.
--- Commit ed1135994 authored by Andrew Harp <andrewharp@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add -latomic flag to the benchmark_model target to fix the Android x86 build. PiperOrigin-RevId: 168281337
--- Commit c0348bb55 authored by Anna R <annarev@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Update tf_export.py to take a constant name as an argument instead of a constant. PiperOrigin-RevId: 168280613
--- Commit c3d19e40a authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Clean up training_ops to reduce code redundancy. PiperOrigin-RevId: 168280069
--- Commit 123fb01ee authored by Yao Zhang <yaozhang@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Set fused=False for batch norm, because the test assumes no Bessel's correction; fused=True would add Bessel's correction to the variance. PiperOrigin-RevId: 168274392
--- Commit f0e8c545e authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Switch resource variables from copy-on-read to copy-on-write. RELNOTES: Change the signature of (C++) GetInputTensorFromVariable in training_op_helpers to support the new copy-on-write semantics of resource variables. PiperOrigin-RevId: 168273249
--- Commit 495cc8e47 authored by Yuan (Terry) Tang <terrytangyuan@users.noreply.github.com>, committed by drpngx <drpngx@users.noreply.github.com>: Minor wording change in the timeseries module's README (#12938):
  * Minor wording change in timeseries module's README.
  * Address comments.
--- Commit f13b876ed authored by Amit Patankar <amitpatankar@google.com>, committed by Amit Patankar <amitpatankar@google.com>: Making the default build-from-source version 1.4.0dev. The whl files that are built will be 1.3.0devDDMMYYYY.
--- Commit 2356c0ff4 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Delete ScopedTFStatus to avoid leaking it for long-running trainers (1+ day). PiperOrigin-RevId: 168259652
--- Commit e15f4cae2 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Don't remove all aliases from the linalg namespace. Get rid of redundant aliases. PiperOrigin-RevId: 168257658
--- Commit c58082642 authored by postBG <profile2697@gmail.com>, committed by drpngx <drpngx@users.noreply.github.com>: Fix minor typo in the Programmer's Guide (#12965):
  * Fix minor typo in Programmer's Guide.
  * Change to "this".
--- Commit 509372c2e authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add FLOPs calculations for a lot of operations. PiperOrigin-RevId: 168256746
--- Commit 80ed8afc0 authored by Francois Chollet <fchollet@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add Flatten to core layers. PiperOrigin-RevId: 168254118
--- Commit a6223c01a authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Fix locking of variables in SparseProximalGradientDescent, AdagradDA, and SparseAdagradDA. PiperOrigin-RevId: 168252530
--- Commit abde00830 authored by Olivia Nordquist <nolivia@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Adding an InputTensor class for symmetry with OutputTensor. PiperOrigin-RevId: 168250085
--- Commit 0451032ca authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Fix a variable naming style guide violation. PiperOrigin-RevId: 168245542
--- Commit a202a5a94 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Update ops-related pbtxt files. PiperOrigin-RevId: 168245371
--- Commit f93e354cb authored by Derek Murray <mrry@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [tf.contrib.data] Switch the backend Dataset representation to DT_VARIANT. This change introduces a new `DatasetWrapper` type that wraps a `DatasetBase*` and can be stored in a DT_VARIANT tensor. All Dataset ops now consume and produce DT_VARIANT instead of DT_RESOURCE, and the underlying implementation is simplified because the `DatasetWrapper` can be passed directly by value without using the `ResourceMgr`. PiperOrigin-RevId: 168240571
--- Commit a4042cd2a authored by Jianwei Xie <xiejw@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Introduces the placeholder for _TrainingExecutor, which serves the implementation of tf.estimator.train_and_evaluate. PiperOrigin-RevId: 168240151
--- Commit 10ba148f7 authored by Peter Hawkins <phawkins@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Switch the control_flow_ops library to use Resource variants of Stack operators instead of the deprecated Ref variants. PiperOrigin-RevId: 168234822
--- Commit ca43fe82b authored by Ali Yahya <alive@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: TFE: Improves the interfaces of tape.watch_variable() and implicit_grad(). tape.watch_variable() replaces tape.watch() and is now called on ResourceVariable objects instead of their underlying handles. implicit_grad() now returns a list of (gradient, variable) pairs to be consistent with tf.Optimizer's interface. PiperOrigin-RevId: 168232055
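A heavily hedged sketch of the updated interface (the entry calls the function implicit_grad(); the names below are the ones tf.contrib.eager later exported, and are assumptions here):

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()
w = tfe.Variable(3.0)

def loss_fn():
    return w * w

# Returns (gradient, variable) pairs, matching what
# tf.train.Optimizer.apply_gradients expects.
grad_fn = tfe.implicit_gradients(loss_fn)
optimizer = tf.train.GradientDescentOptimizer(0.1)
optimizer.apply_gradients(grad_fn())
```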
--- Commit b72862dfc authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Internal change. PiperOrigin-RevId: 168225993
--- Commit da3280f4d authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Re-enable tsan for sdca_estimator_test. PiperOrigin-RevId: 168186374
--- Commit c936c1155 authored by Yifei Feng <yifeif@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Fix pip tests for contrib/gan:
  * Add *_impl.py so tests can still access removed symbols.
  * Add a /python directory layer so that *_impl.py and __init__.py are not in the same directory.
  PiperOrigin-RevId: 168161722
--- Commit ce9a2b00f authored by Toby Boyd <tobyboyd@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Performance guide update. PiperOrigin-RevId: 168159289
--- Commit 3bce4f9a0 authored by Shanqing Cai <cais@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: TFE: expose tfe.num_gpus(). PiperOrigin-RevId: 168154345
--- Commit 67a7cbc28 authored by Jianwei Xie <xiejw@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Changed the default eval throttle secs from 2 minutes to 10 minutes. PiperOrigin-RevId: 168120323
--- Commit 92bed178f authored by Eugene Brevdo <ebrevdo@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Reduce cmake log mess:
  * Echo off for the .bat scripts.
  * TF cmake: disable warnings in some of the patched projects (gif, jpeg, lmdb).
  PiperOrigin-RevId: 168119914
--- Commit 702d59582 authored by joshkyh <joshkyh@users.noreply.github.com>, committed by Yifei Feng <fengyifei2026@gmail.com>: Corrected the hyperlink for the audio training tutorial (#12923).
--- Commit 877c9deca authored by Frank Chen <frankchn@gmail.com>, committed by Yifei Feng <fengyifei2026@gmail.com>: Reverse change eb75ded6 so that internal tests will pass (#12933). As support for int64 global steps is not ready in TPUs, I am reversing this change so that our internal performance and regression tests will pass.
--- Commit 665966438 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Re-enable grpc_session_test. PiperOrigin-RevId: 168078694
--- Commit 405def792 authored by Chris Leary <leary@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Switch CallInliner to use CallGraph::VisitNodes. PiperOrigin-RevId: 168078645
--- Commit aba3466f1 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Exposes Head and factory methods in tf.contrib.estimator. PiperOrigin-RevId: 168071246
--- Commit b76565b39 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Some profiler fixes and cleanup. PiperOrigin-RevId: 168069346
--- Commit 32ffc5a81 authored by Jonas <sauercrowd@users.noreply.github.com>, committed by Yifei Feng <fengyifei2026@gmail.com>: Just a dot in order to be consistent (#12919): added a dot to the `7` to make clear it's a float (like every other number).
--- Commit 0753b0c79 authored by Alexandre Passos <apassos@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Scope the scalar cache in the context. PiperOrigin-RevId: 168065417
--- Commit 48deb206b authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Migrate TFGAN features to third_party. PiperOrigin-RevId: 168060880
--- Commit d2ae1311f authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Fixing an issue in the BUILD file of the LSH ops. PiperOrigin-RevId: 168056645
--- Commit 2f440eda4 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Expose NumpyReader for reading timeseries data. PiperOrigin-RevId: 168055838
--- Commit be1916ce7 authored by Daniel Grazian <dgr@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Added functionality to allow `SqlDataset` to interpret a database column as various numeric types, including several integer types and `dtypes.float64`. PiperOrigin-RevId: 168055827
--- Commit fa2000a0b authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Supporting nightly Windows pip packages. PiperOrigin-RevId: 168054959
--- Commit a263ea626 authored by Asim Shankar <ashankar@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: eager: Treat eager tensors as constants during graph construction, unless capturing is explicitly enabled. PiperOrigin-RevId: 168052675
--- Commit 6e402d0d2 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Make a TODO a bit more specific. PiperOrigin-RevId: 168051381
--- Commit c779384bc authored by Daniel Grazian <dgr@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Added a code example to the doc string for `SqlDataset`. PiperOrigin-RevId: 168049037
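A hedged sketch of SqlDataset with the numeric output types described in the earlier SqlDataset entry (assumes the contrib-era constructor arguments shown):

```python
import tensorflow as tf

# Driver name, data source, query, and per-column output types.
dataset = tf.contrib.data.SqlDataset(
    "sqlite", "/path/to/users.db",
    "SELECT age, height FROM users",
    (tf.int32, tf.float64))
age, height = dataset.make_one_shot_iterator().get_next()
```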
--- Commit ff6dd474a authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Use self._in_graph_mode consistently in ResourceVariable instead of sometimes getting it from the context. Also: fix the formatting of a comment, and use a more precise test to detect whether initial_value is set. PiperOrigin-RevId: 168047258
--- Commit f331f528b authored by Alexandre Passos <apassos@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Removes "fast paths" which are not fast in eager mode. PiperOrigin-RevId: 168046278
--- Commit 86f1713e5 authored by Jianwei Xie <xiejw@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Introduces TrainSpec and EvalSpec. PiperOrigin-RevId: 168040435
--- Commit c8b9e92f0 authored by Asim Shankar <ashankar@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: eager: Move "register_function" to context.py. This will allow function registration from other modules without having to import "function.py". (And besides, the function really does belong on the context.) PiperOrigin-RevId: 168040411
--- Commit 74137f994 authored by Shanqing Cai <cais@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Fix a signed int overflow issue in tensor_id.cc. When a node name has a long numeric suffix, e.g., "foo/y_0/gradient_debug_09684b60f2184c67b744721915034528" (as has happened with the tfdbg GradientsDebugger), the parsing algorithm in ParseTensorName() may experience signed int overflow. Replacing the types with "unsigned int" resolves the issue. PiperOrigin-RevId: 168039195
--- Commit 450c3b562 authored by Rohan Jain <rohanj@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Using the rendezvous manager to pass args / rets between devices during function remote execution. This enables CPU->GPU remote device execution now. PiperOrigin-RevId: 168038285
--- Commit 82cc6529f authored by Jianwei Xie <xiejw@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Fixes the wording about StopIteration. PiperOrigin-RevId: 168034451
--- Commit fb5588002 authored by Gunhan Gulsoy <gunan@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: Add a statement to install/index.md on which OSes are supported. PiperOrigin-RevId: 168032996
--- Commit f83f6b9ef authored by Chris Leary <leary@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>: [XLA] Handle higher-order HLOs (e.g. While) in CallInliner, and test. PiperOrigin-RevId: 168029345
--- Commit 8988ae365 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 167916124

PiperOrigin-RevId: 168916710

* Update ops-related pbtxt files. PiperOrigin-RevId: 168917157
* Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 168917534
This commit is contained in:
parent 7a97dfc3ec
commit e55574f282
@@ -45,11 +45,11 @@ GPU packages on all platforms will arrive soon!
 **Individual whl files**
-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.4.0dev-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.4.0dev-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.4.0dev-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.4.0dev-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
 * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
 ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
@@ -283,7 +283,6 @@ filegroup(
         "//tensorflow/contrib/crf:all_files",
         "//tensorflow/contrib/cudnn_rnn:all_files",
         "//tensorflow/contrib/data:all_files",
-        "//tensorflow/contrib/data/python/framework:all_files",
         "//tensorflow/contrib/data/python/kernel_tests:all_files",
         "//tensorflow/contrib/data/python/ops:all_files",
         "//tensorflow/contrib/data/python/util:all_files",
@@ -418,6 +417,7 @@ filegroup(
         "//tensorflow/python/profiler/internal:all_files",
         "//tensorflow/python/saved_model:all_files",
         "//tensorflow/python/tools:all_files",
+        "//tensorflow/tools/api/generator:all_files",
         "//tensorflow/tools/api/golden:all_files",
         "//tensorflow/tools/api/lib:all_files",
         "//tensorflow/tools/api/tests:all_files",
@@ -127,10 +127,11 @@ Status Conv2DGrad(const Scope& scope, const Operation& op,
   std::vector<int32> strides;
   bool use_cudnn_on_gpu;
   auto attrs = op.output(0).node()->attrs();
-  GetNodeAttr(attrs, "data_format", &data_format);
-  GetNodeAttr(attrs, "padding", &padding);
-  GetNodeAttr(attrs, "strides", &strides);
-  GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu);
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu",
+                                 &use_cudnn_on_gpu));
   Conv2DBackpropInput::Attrs input_attrs;
   input_attrs.DataFormat(data_format);
   input_attrs.UseCudnnOnGpu(use_cudnn_on_gpu);
@@ -157,10 +158,10 @@ Status MaxPoolGradHelper(const Scope& scope, const Operation& op,
   std::vector<int32> strides;
   std::vector<int32> ksize;
   auto attrs = op.output(0).node()->attrs();
-  GetNodeAttr(attrs, "data_format", &data_format);
-  GetNodeAttr(attrs, "ksize", &ksize);
-  GetNodeAttr(attrs, "padding", &padding);
-  GetNodeAttr(attrs, "strides", &strides);
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
   internal::MaxPoolGrad::Attrs grad_attrs;
   grad_attrs.DataFormat(data_format);
   auto dx = internal::MaxPoolGrad(scope, op.input(0),
@@ -179,8 +180,8 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op,
   string data_format;
   string padding;
   auto attrs = op.output(0).node()->attrs();
-  GetNodeAttr(attrs, "data_format", &data_format);
-  GetNodeAttr(attrs, "padding", &padding);
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
   MaxPoolGradV2::Attrs grad_attrs;
   grad_attrs.DataFormat(data_format);
   auto dx = MaxPoolGradV2(scope, op.input(0),
@@ -4,7 +4,6 @@ package(
     default_visibility = ["//visibility:private"],
 )
 
-load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
 load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
 
 # Optional runtime utilities for use by code generated by tfcompile.
@@ -39,32 +38,24 @@ cc_library(
     deps = ["//tensorflow/core:test_main"],
 )
 
-xla_proto_library(
-    name = "tfcompile_proto",
-    srcs = ["tfcompile.proto"],
-    deps = [
-        "//tensorflow/core:protos_all_cc",
-    ],
-)
-
 cc_library(
     name = "tfcompile_lib",
     srcs = [
         "codegen.cc",
         "compile.cc",
        "flags.cc",
-        "tfcompile_util.cc",
     ],
     hdrs = [
         "codegen.h",
         "compile.h",
        "flags.h",
-        "tfcompile_util.h",
     ],
     deps = [
         ":runtime",  # needed by codegen to print aligned_buffer_bytes
-        ":tfcompile_proto",
+        "//tensorflow/compiler/tf2xla",
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:tf2xla_proto",
+        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -82,7 +73,6 @@ cc_library(
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:stream_executor_no_cuda",
     ],
 )
 
@@ -99,18 +89,6 @@ cc_test(
     ],
 )
 
-cc_test(
-    name = "tfcompile_util_test",
-    srcs = ["tfcompile_util_test.cc"],
-    deps = [
-        ":tfcompile_lib",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-    ],
-)
-
 cc_binary(
     name = "tfcompile",
     visibility = ["//visibility:public"],
@@ -123,7 +101,8 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":tfcompile_lib",
-        ":tfcompile_proto",
+        "//tensorflow/compiler/tf2xla:tf2xla_proto",
+        "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/core:core_cpu",
@@ -226,7 +205,11 @@ test_suite(
     tags = ["manual"],
     tests = [
         ":benchmark_test",
         ":codegen_test",
         ":runtime_test",
+        ":test_graph_tfadd_test",
+        ":test_graph_tfunknownop2_test",
+        ":test_graph_tfunknownop3_test",
+        ":test_graph_tfunknownop_test",
         "//tensorflow/compiler/aot/tests:all_tests",
     ],

@@ -20,8 +20,8 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/aot/runtime.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,6 +35,12 @@ namespace tfcompile {

namespace {

bool IsAlpha(char c) {
  return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}

bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }

// Convert an XLA type into a C++ type.
Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
  switch (type) {
@@ -156,7 +162,7 @@ string RewriteWithName(const string& name, string code,
}

// Generate methods for args (inputs).
Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
                     const CompileResult& compile_result, string* methods) {
  *methods += R"(
  void** args() { return args_; }
@@ -204,8 +210,8 @@ Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
}

// Generate methods for results (outputs).
Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
                        string* methods) {
Status GenResultMethods(const tf2xla::Config& config,
                        const xla::ProgramShape& ps, string* methods) {
  if (ps.result().element_type() != xla::TUPLE) {
    // Non-tuple (i.e. single-result) case.
    if (config.fetch_size() != 1) {
@@ -285,11 +291,26 @@ Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
  return Status::OK();
}

Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
  for (const tf2xla::Feed& feed : config.feed()) {
    if (!feed.name().empty()) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name"));
    }
  }
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    if (!fetch.name().empty()) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name"));
    }
  }
  return Status::OK();
}

}  // namespace

Status GenerateHeader(const HeaderOpts& opts, const Config& config,
Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
                      const CompileResult& compile_result, string* header) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
  const int64 result_index = compile_result.aot->result_buffer_index();
  const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
  if (result_index < 0 || result_index > temp_sizes.size()) {
@@ -574,5 +595,29 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
  return Status::OK();
}

Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
  if (ident.empty()) {
    return errors::InvalidArgument("empty identifier: ", msg);
  }
  // Require that the identifier starts with a nondigit, and is composed of
  // nondigits and digits, as specified in section [2.11 Identifiers] of the
  // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
  // defined as [0-9].
  //
  // Technically the standard also allows for `universal-character-name`, with a
  // table of allowed unicode ranges, as well as `other implementation-defined
  // characters`. We disallow those here to give better error messages, at the
  // expense of being more restrictive than the standard.
  if (ident[0] != '_' && !IsAlpha(ident[0])) {
    return errors::InvalidArgument("illegal leading char: ", msg);
  }
  for (size_t pos = 1; pos < ident.size(); ++pos) {
    if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
      return errors::InvalidArgument("illegal char: ", msg);
    }
  }
  return Status::OK();
}

}  // namespace tfcompile
}  // namespace tensorflow

@@ -20,6 +20,8 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"

namespace tensorflow {
namespace tfcompile {
@@ -37,7 +39,7 @@ struct HeaderOpts {
// GenerateHeader uses the meta-information from compile_result to generate a
// C++ header giving access to the function in the generated object file. The
// header includes API usage documentation.
Status GenerateHeader(const HeaderOpts& opts, const Config& config,
Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
                      const CompileResult& compile_result, string* header);

// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces`
@@ -47,6 +49,10 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config,
Status ParseCppClass(const string& cpp_class, string* class_name,
                     std::vector<string>* namespaces);

// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
// appended to error messages.
Status ValidateCppIdent(StringPiece ident, StringPiece msg);
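
A brief usage sketch for the two helpers declared above (editorial illustration, not patch content; the inputs mirror the --cpp_class flag documented in tfcompile_main.cc and the golden test further down):

#include <string>
#include <vector>

#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/core/lib/core/status.h"

void CodegenHelpersDemo() {
  std::string class_name;
  std::vector<std::string> namespaces;
  // Splits a qualified name: class_name becomes "MyClass" and namespaces
  // becomes {"foo", "bar"}.
  TF_CHECK_OK(tensorflow::tfcompile::ParseCppClass("foo::bar::MyClass",
                                                   &class_name, &namespaces));
  // Accepts plain C++11 identifiers; a leading digit or embedded punctuation
  // yields errors::InvalidArgument, as the test below exercises.
  TF_CHECK_OK(tensorflow::tfcompile::ValidateCppIdent("_abc123", "demo ident"));
}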

}  // namespace tfcompile
}  // namespace tensorflow

@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -29,6 +30,41 @@ namespace tensorflow {
namespace tfcompile {
namespace {

void ExpectErrorContains(const Status& status, StringPiece str) {
  EXPECT_NE(Status::OK(), status);
  EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
      << "expected error: " << status.error_message() << " to contain: " << str;
}

TEST(ValidateCppIdent, Simple) {
  TF_EXPECT_OK(ValidateCppIdent("a", ""));
  TF_EXPECT_OK(ValidateCppIdent("abc", ""));
  TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
  TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
  // Make sure we didn't skip a valid letter or digit
  string ident;
  for (char c = 'a'; c <= 'z'; c++) {
    ident.append(1, c);
  }
  for (char c = 'A'; c <= 'Z'; c++) {
    ident.append(1, c);
  }
  for (char c = '0'; c <= '9'; c++) {
    ident.append(1, c);
  }
  ident += "_";
  TF_EXPECT_OK(ValidateCppIdent(ident, ""));

  ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
  ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
  ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
  ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
}

class ParseCppClassTest : public ::testing::Test {
 protected:
  void ExpectOK(const string& cpp_class, const string& want_class_name,
@@ -91,13 +127,13 @@ TEST(GenerateHeader, Golden) {
  HeaderOpts opts;
  opts.class_name = "MyClass";
  opts.namespaces = {"foo", "bar"};
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("feed0");
  feed->set_name("myfeed");
  feed = config.add_feed();
  feed->mutable_id()->set_node_name("feed1");
  Fetch* fetch = config.add_fetch();
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("fetch0");
  fetch->set_name("myfetch");
  CompileResult compile_result;

@@ -15,326 +15,32 @@ limitations under the License.

#include "tensorflow/compiler/aot/compile.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace tfcompile {

const char* const kArgOp = "_Arg";
const char* const kRetvalOp = "_Retval";
const char* const kFeedIdAttr = "_feed_id";
const char* const kFetchIdAttr = "_fetch_id";
const char* const kShapeAttr = "_shape";
const char* const kDebugNameAttr = "_debug_name";

namespace {

Status DumpGraph(const MainFlags& flags, const string& name,
                 const Graph& graph) {
  if (flags.debug_dir.empty()) {
    return Status::OK();
  }
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  string file = io::JoinPath(flags.debug_dir, name + ".pbtxt");
  return WriteTextProto(Env::Default(), file, graph_def);
}

typedef std::unordered_map<string, Node*> NodeMap;

// Each feed id identifies the positional output of some node, which may consist
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
// tensor with a placeholder. For each feed tensor, replaces all edges so they
// point from a new _Arg node instead.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
                   const protobuf::RepeatedPtrField<Feed>& feeds,
                   const std::unordered_map<string, string>& feed_remapping) {
  for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
    const Feed& feed = feeds[arg_index];
    // All feeds have been replaced by placeholders.
    const int output_index = 0;

    const string key = TensorIdToString(feed.id());
    const auto remap_it = feed_remapping.find(key);
    auto node_it = node_map.find(remap_it->second);
    if (node_it == node_map.end()) {
      // Strip off the aot_feed_#/ prefix.
      StringPiece name(remap_it->second);
      const auto index = name.find('/');
      if (index > 0) name.remove_prefix(index + 1);
      return errors::InvalidArgument(
          "Node is fed but not needed for fetching: ", name);
    }
    const Node* feed_node = node_it->second;

    // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
    // "_shape" attr if we can determine it. That way the graph will be
    // initialized with whatever shapes we can infer, while the user can still
    // explicitly specify or override them.
    Node* arg_node = nullptr;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
            .Attr("T", BaseType(feed_node->output_type(output_index)))
            .Attr("index", arg_index)
            .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
            .Attr(kShapeAttr, TensorShape(feed.shape()))
            .Attr(kDebugNameAttr, feed.name())
            .Finalize(graph, &arg_node));

    // Collects out-edges from the feed node that have a matching edge index;
    // these will be replaced with edges from the arg node instead.
    //
    // We must collect the edges first and process them in a second pass, since
    // removing the edge from the graph invalidates feed_node->out_edges.
    std::vector<const Edge*> feed_edges;
    for (const Edge* edge : feed_node->out_edges()) {
      if (edge->src_output() == output_index) {
        feed_edges.push_back(edge);
      }
    }
    for (const Edge* edge : feed_edges) {
      graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
      graph->RemoveEdge(edge);
    }
  }
  return Status::OK();
}

// Each fetch id identifies the positional output of some node. For each fetch
// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
                      const protobuf::RepeatedPtrField<Fetch>& fetches,
                      std::unordered_set<const Node*>* retval_nodes) {
  for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
    const TensorId& id = fetches[ret_index].id();
    auto it = node_map.find(id.node_name());
    if (it == node_map.end()) {
      return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
    }
    Node* fetch_node = it->second;
    if (id.output_index() >= fetch_node->num_outputs()) {
      return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
                                     ", output index should be < ",
                                     fetch_node->num_outputs());
    }
    // Connects fetch_node -> retval_node.
    Node* retval_node = nullptr;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
            .Input(fetch_node, id.output_index())
            .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
            .Attr("index", ret_index)
            .Attr(kFetchIdAttr, TensorIdToString(id))
            .Finalize(graph, &retval_node));
    retval_nodes->insert(retval_node);
  }
  return Status::OK();
}

// RewriteAndPruneGraph identifies input and output edges (named by the feed and
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
// execution to know the input and output args for the generated function.
Status RewriteAndPruneGraph(
    Graph* graph, const Config& config,
    const std::unordered_map<string, string>& feed_remapping,
    const MainFlags& flags) {
  NodeMap node_map;
  for (Node* n : graph->nodes()) {
    node_map[n->name()] = n;
  }
  TF_RETURN_IF_ERROR(
      AddArgNodes(graph, node_map, config.feed(), feed_remapping));
  std::unordered_set<const Node*> retval_nodes;
  TF_RETURN_IF_ERROR(
      AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
  TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph));
  PruneForReverseReachability(graph, retval_nodes);
  FixupSourceAndSinkEdges(graph);
  TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph));
  // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
  std::set<string> missing_feeds, missing_fetches;
  for (const Feed& feed : config.feed()) {
    missing_feeds.insert(TensorIdToString(feed.id()));
  }
  for (const Fetch& fetch : config.fetch()) {
    missing_fetches.insert(TensorIdToString(fetch.id()));
  }
  for (const Node* n : graph->op_nodes()) {
    if (n->type_string() == kArgOp) {
      string feed_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
      if (missing_feeds.erase(feed_id) == 0) {
        return errors::Aborted(kArgOp,
                               " node found with unknown feed id: ", feed_id);
      }
    } else if (n->type_string() == kRetvalOp) {
      string fetch_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
      if (missing_fetches.erase(fetch_id) == 0) {
        return errors::Aborted(kRetvalOp,
                               " node found with unknown fetch id: ", fetch_id);
      }
    }
  }
  if (!missing_feeds.empty() || !missing_fetches.empty()) {
    return errors::Aborted(
        "Post graph-pruning",
        ", missing feeds: ", str_util::Join(missing_feeds, ", "),
        ", missing fetches: ", str_util::Join(missing_fetches, ", "));
  }
  return Status::OK();
}

// CollectArgNodes collects _Arg nodes from the graph, and performs basic
// sanity-checking to ensure the index and type attributes of each node are
// initialized correctly.
Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
  std::map<int, Node*> indexed_arg_nodes;
  for (Node* n : graph.nodes()) {
    if (n->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      auto insert_result = indexed_arg_nodes.insert({index, n});
      if (!insert_result.second) {
        const Node* dup = insert_result.first->second;
        return errors::InvalidArgument(
            "Multiple ", kArgOp, " nodes with index ", index, ", ",
            n->DebugString(), " and ", dup->DebugString());
      }
    }
  }
  arg_nodes->clear();
  for (const auto& index_node : indexed_arg_nodes) {
    if (index_node.first != arg_nodes->size()) {
      return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
                                     arg_nodes->size(), ", but got index ",
                                     index_node.first);
    }
    arg_nodes->push_back(index_node.second);
  }
  return Status::OK();
}

// Fills in xla_args from the corresponding _Arg nodes in the graph.
Status CreateXlaArgs(const Graph& graph,
                     std::vector<XlaCompiler::Argument>* xla_args) {
  std::vector<Node*> arg_nodes;
  TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
  for (const Node* node : arg_nodes) {
    XlaCompiler::Argument arg;
    arg.kind = XlaCompiler::Argument::kParameter;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
    TensorShape shape;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
    xla_args->push_back(arg);
  }
  return Status::OK();
}

// Converts the TensorFlow graph into an XLA computation, by executing the
// graph symbolically, with each op building up the XLA HLO.
Status ConvertGraphToXla(xla::CompileOnlyClient* client,
                         std::unique_ptr<Graph> graph,
                         xla::Computation* computation, bool* has_context_arg) {
  // Create a device and context to convert the graph into an XLA computation.
  XlaOpRegistry::RegisterCompilationKernels();
  // Populate the context with args from the graph.
  for (Node* node : graph->nodes()) {
    node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
  }
  std::vector<XlaCompiler::Argument> xla_args;
  TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));

  // Compile the graph into an XLA computation.
  XlaCompiler::Options compiler_options;
  compiler_options.client = client;
  DeviceType device_type(DEVICE_CPU_XLA_JIT);
  compiler_options.device_type = &device_type;
  compiler_options.flib_def = &graph->flib_def();
  compiler_options.graph_def_version = graph->versions().producer();
  compiler_options.allow_cpu_custom_calls = true;
  XlaCompiler compiler(compiler_options);

  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
                                           "tfcompile", std::move(graph),
                                           xla_args, &result));
  *has_context_arg = result.requires_runtime_context;
  *computation = std::move(*result.computation);

  int num_const_results = 0;
  for (int i = 0; i < result.outputs.size(); ++i) {
    // Ending up with const results (i.e. output args) is an error, since it
    // means that one or more fetches that the user specified will be dropped
    // from the generated function. It's most likely a configuration error,
    // since the user shouldn't be asking for output args that end up as consts.
    //
    // TODO(toddw): Provide a way for the user to access const output args,
    // e.g. perhaps hard-coded into the header, or somehow copied into the
    // output buffers.
    if (result.outputs[i].is_constant) {
      ++num_const_results;
      LOG(ERROR) << "ConstRetVal index:" << i
                 << " value:" << result.outputs[i].constant_value.DebugString();
    }
  }
  if (num_const_results > 0) {
    return errors::Unimplemented(
        "Conversion from TensorFlow graph to XLA resulted in ",
        num_const_results,
        " constant results. The configuration of "
        "the output args (i.e. fetch ids) is probably wrong.");
  }
  if (computation->IsNull()) {
    return errors::Aborted(
        "Conversion from TensorFlow graph to XLA resulted in an empty "
        "computation.");
  }
  return Status::OK();
}

// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
                  const xla::Computation& computation,
@@ -376,41 +82,8 @@ Status CompileXla(xla::CompileOnlyClient* client,

}  // namespace

Status InitGraph(const GraphDef& graph_def, const Config& config,
                 const MainFlags& flags, std::unique_ptr<Graph>* graph) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
  std::unique_ptr<Graph> g(new Graph(flib_def));

  // Replace references to fed tensors with references to newly added
  // placeholders.
  GraphDef first_copy_def = graph_def;

  // Maps from name:port of a feed to the name:port of the placeholder to use.
  std::unordered_map<string, string> feed_remapping;
  TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
                                             &feed_remapping, &first_copy_def));

  // Prune the GraphDef first so that unknown ops that we aren't compiling get
  // filtered out.
  GraphDef second_copy_def;
  TF_RETURN_IF_ERROR(
      PruneGraphDefInto(config, first_copy_def, &second_copy_def));

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
      &second_copy_def, *g->op_registry(), 0 /*node_offset*/));

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            second_copy_def, g.get()));
  TF_RETURN_IF_ERROR(
      RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
  *graph = std::move(g);
  return Status::OK();
}

Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
                    CompileResult* compile_result) {
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result) {
  // Converts the graph into an XLA computation, and compiles the
  // computation.
  // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
@@ -421,8 +94,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
          .ValueOrDie();
  xla::Computation computation;
  TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
                                       &compile_result->has_context_arg));
  TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client,
                                          &computation,
                                          &compile_result->has_context_arg));
  if (!flags.debug_dir.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
                        computation.Snapshot());
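
For context on the client plumbing that the rewritten CompileGraph keeps: the hunk above still obtains an xla::CompileOnlyClient before delegating to ConvertGraphDefToXla. A hedged sketch of that setup (the "Host" platform lookup follows the usual XLA pattern and is illustrative, not quoted from the patch):

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

// Illustrative: a compile-only client compiles ahead-of-time without
// allocating device buffers, which is what tfcompile needs.
xla::CompileOnlyClient* GetHostCompileOnlyClient() {
  perftools::gputools::Platform* cpu_platform =
      perftools::gputools::MultiPlatformManager::PlatformWithName("Host")
          .ValueOrDie();
  return xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
      .ValueOrDie();
}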
@@ -18,46 +18,16 @@ limitations under the License.

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
namespace tfcompile {

// Constants for op types and attribute names.
extern const char* const kArgOp;
extern const char* const kRetvalOp;
extern const char* const kFeedIdAttr;
extern const char* const kFetchIdAttr;
extern const char* const kShapeAttr;
extern const char* const kDebugNameAttr;

// InitGraph creates a graph based on the graph_def, that may then be compiled
// by CompileGraph.
//
// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
// and outputs of the function that will be compiled. Each feed id causes a new
// _Arg node to be created, where we first collect all existing edges pointing
// from the named node's output index, and then rewrite them to point from that
// _Arg node instead. Each fetch id causes a new _Retval node to be created,
// with a new edge pointing from the named node's output index to that _Retval
// node. All _Retval nodes also point to a special CompileExpressions node,
// used internally to finish the compilation.
//
// The rewritten graph is then pruned to only contain the portion necessary to
// compute the outputs. If dump_graphs is true, graph rewrites will be dumped
// for debugging.
Status InitGraph(const GraphDef& graph_def, const Config& config,
                 const MainFlags& flags, std::unique_ptr<Graph>* graph);

// CompileResult describes the output of CompileGraph, where the object file
// data and meta-information is available in aot.
struct CompileResult {
@@ -69,20 +39,12 @@ struct CompileResult {
  int pointer_size = 0;  // Size of a pointer in bytes.
};

// CompileGraph compiles the graph into an object file containing a function
// CompileGraph compiles the graph_def into an object file containing a function
// that performs the graph operations.
//
// The graph must have _Arg and _Retval nodes representing the function inputs
// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr,
// value=TensorShape) representing the static shape of that input, and every
// _Retval node must point to a CompileExpressions node.
//
// Typically InitGraph is called to perform this initialization, followed by
// full specification of the shape attributes.
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
                    CompileResult* result);
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result);
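
The signature change above removes the separate InitGraph step: callers now hand CompileGraph the GraphDef and Config directly. A hedged sketch of the new call site, modeled on the tfcompile_main.cc hunk further down (the driver function CompileOnce is hypothetical; graph_def and config are assumed to be already parsed):

#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace tfcompile {

Status CompileOnce(const GraphDef& graph_def, const tf2xla::Config& config,
                   const MainFlags& flags) {
  CompileResult compile_result;
  // Previously: InitGraph(graph_def, config, flags, &graph) followed by
  // CompileGraph(std::move(graph), flags, &compile_result).
  TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
  return Status::OK();
}

}  // namespace tfcompile
}  // namespace tensorflow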

}  // namespace tfcompile
}  // namespace tensorflow

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_const" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_const" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_const" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_const" }
  shape {

@@ -13,9 +13,11 @@ test_suite(
        ":test_graph_tfadd_test",
        ":test_graph_tfadd_with_ckpt_saver_test",
        ":test_graph_tfadd_with_ckpt_test",
        ":test_graph_tffunction_test",
        ":test_graph_tfgather_test",
        ":test_graph_tfmatmul_test",
        ":test_graph_tfmatmulandadd_test",
        ":test_graph_tfsplits_test",
        ":tfcompile_test",
    ],
)
@@ -90,6 +92,15 @@ tf_library(
    tags = ["manual"],
)

tf_library(
    name = "test_graph_tffunction",
    testonly = 1,
    config = "test_graph_tffunction.config.pbtxt",
    cpp_class = "FunctionComp",
    graph = "test_graph_tffunction.pb",
    tags = ["manual"],
)

tf_library(
    name = "test_graph_tfgather",
    testonly = 1,
@@ -117,15 +128,6 @@ tf_library(
    tags = ["manual"],
)

tf_library(
    name = "test_graph_tffunction",
    testonly = 1,
    config = "test_graph_tffunction.config.pbtxt",
    cpp_class = "FunctionComp",
    graph = "test_graph_tffunction.pb",
    tags = ["manual"],
)

tf_library(
    name = "test_graph_tfsplits",
    testonly = 1,

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_const" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_hold" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_const" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "params" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_hold" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x_hold" }
  shape {

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed {
  id { node_name: "x" }
  shape {

@@ -41,7 +41,7 @@ def tf_library(name, graph, config,
    graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
      is expected to be in the human-readable proto text format, otherwise it is
      expected to be in the proto binary format.
    config: File containing tensorflow.tfcompile.Config proto. If the file ends
    config: File containing tensorflow.tf2xla.Config proto. If the file ends
      in '.pbtxt' it is expected to be in the human-readable proto text format,
      otherwise it is expected to be in the proto binary format.
    freeze_checkpoint: If provided, run freeze_graph with this checkpoint to

@@ -21,8 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
@@ -54,8 +54,7 @@ const char kUsageHeader[] =
    "--cpp_class=\"mynamespace::MyComputation\"\n"
    "\n";

Status ReadProtoFile(const string& kind, const string& fname,
                     protobuf::Message* proto) {
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  if (StringPiece(fname).ends_with(".pbtxt")) {
    return ReadTextProto(Env::Default(), fname, proto);
  } else {
@@ -63,23 +62,17 @@ Status ReadProtoFile(const string& kind, const string& fname,
  }
}

void ParseTensorId(const string& name, TensorId* id) {
  const std::pair<StringPiece, int> name_index = ParseTensorName(name);
  id->set_node_name(name_index.first.ToString());
  id->set_output_index(name_index.second);
}

Status Main(const MainFlags& flags) {
  // Process config.
  Config config;
  tf2xla::Config config;
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  if (flags.dump_fetch_nodes) {
    std::set<string> nodes;
    for (const Fetch& fetch : config.fetch()) {
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << str_util::Join(nodes, ",");
@@ -91,12 +84,9 @@ Status Main(const MainFlags& flags) {
    return errors::InvalidArgument("Must specify --graph");
  }
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
  std::unique_ptr<Graph> graph;
  TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph));

  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;
  TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &compile_result));
  TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));

  // Write output files.
  Env* env = Env::Default();

@@ -1,4 +1,4 @@
# Text form of tensorflow.tfcompile.Config proto.
# Text form of tensorflow.tf2xla.Config proto.
feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} }

@@ -21,6 +21,40 @@ package(
)

load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")

xla_proto_library(
    name = "tf2xla_proto",
    srcs = ["tf2xla.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "tf2xla",
    srcs = ["tf2xla.cc"],
    hdrs = ["tf2xla.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":common",
        ":dump_graph",
        ":tf2xla_proto",
        ":tf2xla_util",
        ":xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/client",
        "//tensorflow/compiler/xla/client:computation",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "xla_compiler",
@@ -96,6 +130,51 @@ cc_library(

# Internal targets below this point.

cc_library(
    name = "tf2xla_util",
    srcs = ["tf2xla_util.cc"],
    hdrs = ["tf2xla_util.h"],
    deps = [
        ":tf2xla_proto",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_test(
    name = "tf2xla_util_test",
    srcs = ["tf2xla_util_test.cc"],
    deps = [
        ":tf2xla_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_test(
    name = "tf2xla_test",
    srcs = ["tf2xla_test.cc"],
    deps = [
        ":tf2xla",
        ":tf2xla_proto",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_test(
    name = "xla_compiler_test",
    srcs = ["xla_compiler_test.cc"],

tensorflow/compiler/tf2xla/tf2xla.cc (new file, 370 lines)
@@ -0,0 +1,370 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kArgOp = "_Arg";
|
||||
const char* const kRetvalOp = "_Retval";
|
||||
const char* const kFeedIdAttr = "_feed_id";
|
||||
const char* const kFetchIdAttr = "_fetch_id";
|
||||
const char* const kShapeAttr = "_shape";
|
||||
const char* const kDebugNameAttr = "_debug_name";
|
||||
|
||||
namespace {
|
||||
|
||||
typedef std::unordered_map<string, Node*> NodeMap;
|
||||
|
||||
// Each feed id identifies the positional output of some node, which may consist
|
||||
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
|
||||
// tensor with a placeholder. For each feed tensor, replaces all edges so they
|
||||
// point from a new _Arg node instead.
|
||||
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
|
||||
const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
|
||||
const std::unordered_map<string, string>& feed_remapping) {
|
||||
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
|
||||
const tf2xla::Feed& feed = feeds[arg_index];
|
||||
// All feeds have been replaced by placeholders.
|
||||
const int output_index = 0;
|
||||
|
||||
const string key = TensorIdToString(feed.id());
|
||||
const auto remap_it = feed_remapping.find(key);
|
||||
auto node_it = node_map.find(remap_it->second);
|
||||
if (node_it == node_map.end()) {
|
||||
// Strip off the aot_feed_#/ prefix.
|
||||
StringPiece name(remap_it->second);
|
||||
const auto index = name.find('/');
|
||||
if (index > 0) name.remove_prefix(index + 1);
|
||||
return errors::InvalidArgument(
|
||||
"Node is fed but not needed for fetching: ", name);
|
||||
}
|
||||
const Node* feed_node = node_it->second;
|
||||
|
||||
// TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
|
||||
// "_shape" attr if we can determine it. That way the graph will be
|
||||
// initialized with whatever shapes we can infer, while the user can still
|
||||
// explicitly specify or override them.
|
||||
Node* arg_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
|
||||
.Attr("T", BaseType(feed_node->output_type(output_index)))
|
||||
.Attr("index", arg_index)
|
||||
.Attr(kFeedIdAttr, TensorIdToString(feed.id()))
|
||||
.Attr(kShapeAttr, TensorShape(feed.shape()))
|
||||
.Attr(kDebugNameAttr, feed.name())
|
||||
.Finalize(graph, &arg_node));
|
||||
|
||||
// Collects out-edges from the feed node that have a matching edge index;
|
||||
// these will be replaced with edges from the arg node instead.
|
||||
//
|
||||
// We must collect the edges first and process them in a second pass, since
|
||||
// removing the edge from the graph invalidates feed_node->out_edges.
|
||||
std::vector<const Edge*> feed_edges;
|
||||
for (const Edge* edge : feed_node->out_edges()) {
|
||||
if (edge->src_output() == output_index) {
|
||||
feed_edges.push_back(edge);
|
||||
}
|
||||
}
|
||||
for (const Edge* edge : feed_edges) {
|
||||
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
|
||||
graph->RemoveEdge(edge);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Each fetch id identifies the positional output of some node. For each fetch
|
||||
// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
|
||||
Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
|
||||
const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches,
|
||||
std::unordered_set<const Node*>* retval_nodes) {
|
||||
for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
|
||||
const tf2xla::TensorId& id = fetches[ret_index].id();
|
||||
auto it = node_map.find(id.node_name());
|
||||
if (it == node_map.end()) {
|
||||
return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
|
||||
}
|
||||
Node* fetch_node = it->second;
|
||||
if (id.output_index() >= fetch_node->num_outputs()) {
|
||||
return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
|
||||
", output index should be < ",
|
||||
fetch_node->num_outputs());
|
||||
}
|
||||
// Connects fetch_node -> retval_node.
|
||||
Node* retval_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
|
||||
.Input(fetch_node, id.output_index())
|
||||
.Attr("T", BaseType(fetch_node->output_type(id.output_index())))
|
||||
.Attr("index", ret_index)
|
||||
.Attr(kFetchIdAttr, TensorIdToString(id))
|
||||
.Finalize(graph, &retval_node));
|
||||
retval_nodes->insert(retval_node);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// RewriteAndPruneGraph identifies input and output edges (named by the feed and
|
||||
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
|
||||
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
|
||||
// execution to know the input and output args for the generated function.
|
||||
Status RewriteAndPruneGraph(
|
||||
Graph* graph, const tf2xla::Config& config,
|
||||
const std::unordered_map<string, string>& feed_remapping) {
|
||||
NodeMap node_map;
|
||||
for (Node* n : graph->nodes()) {
|
||||
node_map[n->name()] = n;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddArgNodes(graph, node_map, config.feed(), feed_remapping));
|
||||
std::unordered_set<const Node*> retval_nodes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
||||
VLOG(2) << "Post rewrite: "
|
||||
<< dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph);
|
||||
PruneForReverseReachability(graph, retval_nodes);
|
||||
FixupSourceAndSinkEdges(graph);
|
||||
VLOG(2) << "Post prune: "
|
||||
<< dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph);
|
||||
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
|
||||
std::set<string> missing_feeds, missing_fetches;
|
||||
for (const tf2xla::Feed& feed : config.feed()) {
|
||||
missing_feeds.insert(TensorIdToString(feed.id()));
|
||||
}
|
||||
for (const tf2xla::Fetch& fetch : config.fetch()) {
|
||||
missing_fetches.insert(TensorIdToString(fetch.id()));
|
||||
}
|
||||
for (const Node* n : graph->op_nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
string feed_id;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
|
||||
if (missing_feeds.erase(feed_id) == 0) {
|
||||
return errors::Aborted(kArgOp,
|
||||
" node found with unknown feed id: ", feed_id);
|
||||
}
|
||||
} else if (n->type_string() == kRetvalOp) {
|
||||
string fetch_id;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
|
||||
if (missing_fetches.erase(fetch_id) == 0) {
|
||||
return errors::Aborted(kRetvalOp,
|
||||
" node found with unknown fetch id: ", fetch_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!missing_feeds.empty() || !missing_fetches.empty()) {
|
||||
return errors::Aborted(
|
||||
"Post graph-pruning",
|
||||
", missing feeds: ", str_util::Join(missing_feeds, ", "),
|
||||
", missing fetches: ", str_util::Join(missing_fetches, ", "));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// CollectArgNodes collects _Arg nodes from the graph, and performs basic
|
||||
// sanity-checking to ensure the index and type attributes of each node are
|
||||
// initialized correctly.
|
||||
Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
|
||||
std::map<int, Node*> indexed_arg_nodes;
|
||||
for (Node* n : graph.nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
auto insert_result = indexed_arg_nodes.insert({index, n});
|
||||
if (!insert_result.second) {
|
||||
const Node* dup = insert_result.first->second;
|
||||
return errors::InvalidArgument(
|
||||
"Multiple ", kArgOp, " nodes with index ", index, ", ",
|
||||
n->DebugString(), " and ", dup->DebugString());
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_nodes->clear();
|
||||
for (const auto& index_node : indexed_arg_nodes) {
|
||||
if (index_node.first != arg_nodes->size()) {
|
||||
return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
|
||||
arg_nodes->size(), ", but got index ",
|
||||
index_node.first);
|
||||
}
|
||||
arg_nodes->push_back(index_node.second);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Fills in xla_args from the corresponding _Arg nodes in the graph.
|
||||
Status CreateXlaArgs(const Graph& graph,
|
||||
std::vector<XlaCompiler::Argument>* xla_args) {
|
||||
std::vector<Node*> arg_nodes;
|
||||
TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
|
||||
for (const Node* node : arg_nodes) {
|
||||
XlaCompiler::Argument arg;
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
|
||||
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
    xla_args->push_back(arg);
  }
  return Status::OK();
}

// Converts the TensorFlow graph into an XLA computation, by executing the
// graph symbolically, with each op building up the XLA HLO.
Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
                         xla::Computation* computation,
                         bool* requires_runtime_context) {
  // Create a device and context to convert the graph into an XLA computation.
  XlaOpRegistry::RegisterCompilationKernels();
  // Populate the context with args from the graph.
  for (Node* node : graph->nodes()) {
    node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
  }
  std::vector<XlaCompiler::Argument> xla_args;
  TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));

  // Compile the graph into an XLA computation.
  XlaCompiler::Options compiler_options;
  compiler_options.client = client;
  DeviceType device_type(DEVICE_CPU_XLA_JIT);
  compiler_options.device_type = &device_type;
  compiler_options.flib_def = &graph->flib_def();
  compiler_options.graph_def_version = graph->versions().producer();
  compiler_options.allow_cpu_custom_calls = true;
  XlaCompiler compiler(compiler_options);

  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
                                           "tfcompile", std::move(graph),
                                           xla_args, &result));
  *requires_runtime_context = result.requires_runtime_context;
  *computation = std::move(*result.computation);

  int num_const_results = 0;
  for (int i = 0; i < result.outputs.size(); ++i) {
    // Ending up with const results (i.e. output args) is an error, since it
    // means that one or more fetches that the user specified will be dropped
    // from the generated function. It's most likely a configuration error,
    // since the user shouldn't be asking for output args that end up as consts.
    //
    // TODO(toddw): Provide a way for the user to access const output args,
    // e.g. perhaps hard-coded into the header, or somehow copied into the
    // output buffers.
    if (result.outputs[i].is_constant) {
      ++num_const_results;
      LOG(ERROR) << "ConstRetVal index:" << i
                 << " value:" << result.outputs[i].constant_value.DebugString();
    }
  }
  if (num_const_results > 0) {
    return errors::Unimplemented(
        "Conversion from TensorFlow graph to XLA resulted in ",
        num_const_results,
        " constant results. The configuration of "
        "the output args (i.e. fetch ids) is probably wrong.");
  }
  if (computation->IsNull()) {
    return errors::Aborted(
        "Conversion from TensorFlow graph to XLA resulted in an empty "
        "computation.");
  }
  return Status::OK();
}

// InitGraph creates a graph based on the graph_def, that may then be converted
// to an xla::Computation via ConvertGraphToXla.
//
// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
// and outputs of the function that will be compiled. Each feed id causes a new
// _Arg node to be created, where we first collect all existing edges pointing
// from the named node's output index, and then rewrite them to point from that
// _Arg node instead. Each fetch id causes a new _Retval node to be created,
// with a new edge pointing from the named node's output index to that _Retval
// node.
Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                 std::unique_ptr<Graph>* graph) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
  std::unique_ptr<Graph> g(new Graph(flib_def));

  // Replace references to fed tensors with references to newly added
  // placeholders.
  GraphDef first_copy_def = graph_def;

  // Maps from name:port of a feed to the name:port of the placeholder to use.
  std::unordered_map<string, string> feed_remapping;
  TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
                                             &feed_remapping, &first_copy_def));

  // Prune the GraphDef first so that unknown ops that we aren't compiling get
  // filtered out.
  GraphDef second_copy_def;
  TF_RETURN_IF_ERROR(
      PruneGraphDefInto(config, first_copy_def, &second_copy_def));

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
      &second_copy_def, *g->op_registry(), /*node_offset=*/0));

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            second_copy_def, g.get()));
  TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
  *graph = std::move(g);
  return Status::OK();
}
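// For example, with feeds {x:0, y:0} and a fetch {sum:0} (the configuration
// used by the Sum test added below), this rewrite creates one _Arg node per
// feed, repoints the edges that previously consumed x:0 and y:0 at those
// _Arg nodes, and appends a _Retval node fed by sum's output 0.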

}  // namespace

Status ConvertGraphDefToXla(const GraphDef& graph_def,
                            const tf2xla::Config& config, xla::Client* client,
                            xla::Computation* computation,
                            bool* requires_runtime_context) {
  std::unique_ptr<Graph> graph;
  TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
  TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation,
                                       requires_runtime_context));
  return Status::OK();
}

}  // namespace tensorflow
43
tensorflow/compiler/tf2xla/tf2xla.h
Normal file
@@ -0,0 +1,43 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_

#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/core/framework/graph.pb.h"

namespace tensorflow {

// Converts a tensorflow::GraphDef into an xla::Computation. The given `config`
// specifies the portion of the graph to convert, via feeds and fetches. Each
// feed is a positional input argument for the generated computation, while
// each fetch is a positional output argument.
//
// The computation is built in the context of the given `client`, which may
// subsequently be used to compile or execute the computation.
//
// If `requires_runtime_context` is filled with true, this indicates the last
// argument of the computation is XlaLocalRuntimeContext*.
Status ConvertGraphDefToXla(const GraphDef& graph_def,
                            const tf2xla::Config& config, xla::Client* client,
                            xla::Computation* computation,
                            bool* requires_runtime_context);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
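For orientation, a minimal sketch of driving this entry point end to end,
mirroring the ConvertGraphDefToXla.Sum test added below (`graph_def` and
`config` are assumed to be populated by the caller):

  xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
  xla::Computation computation;
  bool requires_runtime_context;
  TF_CHECK_OK(tensorflow::ConvertGraphDefToXla(
      graph_def, config, client, &computation, &requires_runtime_context));
  // `computation` can now be compiled or executed via `client`.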
@@ -1,10 +1,10 @@
syntax = "proto3";

package tensorflow.tfcompile;
package tensorflow.tf2xla;
option cc_enable_arenas = true;
option java_outer_classname = "CompileProtos";
option java_outer_classname = "Tf2XlaProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.tfcompile";
option java_package = "org.tensorflow.tf2xla";

import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
@@ -19,32 +19,32 @@ message TensorId {
};

// Feed represents a single feed tensor in the graph, which corresponds to an
// input argument for the generated function.
// input argument for the generated computation.
message Feed {
  TensorId id = 1;
  TensorShapeProto shape = 2;
  string name = 3;  // Optional name for generated code.

  // Optional data type. This is not normally required, as the graph itself
  // contains this information. However, if the node being fed is an op that
  // is not linked into the tfcompile binary, then the type cannot be inferred
  // from the node; in this case, the type should be set here.
  // contains this information. However, if the node being fed is an op that is
  // not linked into the binary, then the type cannot be inferred from the node;
  // in this case, the type should be set here.
  DataType type = 4;
};

// Fetch represents a single fetch tensor in the graph, which corresponds to an
// output argument for the generated function.
// output argument for the generated computation.
message Fetch {
  TensorId id = 1;
  string name = 2;  // Optional name for generated code.
};

// Config represents configuration information for tfcompile.
// Config represents configuration information for tf2xla conversion.
message Config {
  // Each feed is a positional input argument for the generated function. The
  // order of each entry matches the order of each input argument.
  // Each feed is a positional input argument for the generated computation.
  // The order of each entry matches the order of each input argument.
  repeated Feed feed = 1;
  // Each fetch is a positional output argument for the generated function. The
  // order of each entry matches the order of each output argument.
  // Each fetch is a positional output argument for the generated computation.
  // The order of each entry matches the order of each output argument.
  repeated Fetch fetch = 2;
};
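As a concrete illustration of the renamed messages, a Config with one feed and
one fetch can be built from C++ like this (a sketch; the node names "x" and
"sum" are placeholders):

  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("x");     // feed tensor x:0
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("sum");  // fetch tensor sum:0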
99
tensorflow/compiler/tf2xla/tf2xla_test.cc
Normal file
@@ -0,0 +1,99 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla.h"

#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

AttrValue TypeAttrValue(DataType type) {
  AttrValue attr_value;
  SetAttrValue(type, &attr_value);
  return attr_value;
}

GraphDef SumGraph() {
  GraphDef graph_def;
  NodeDef* x = graph_def.add_node();
  x->set_name("x");
  x->set_op("Placeholder");
  (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
  NodeDef* y = graph_def.add_node();
  y->set_name("y");
  y->set_op("Placeholder");
  (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
  NodeDef* sum = graph_def.add_node();
  sum->set_name("sum");
  sum->set_op("Add");
  sum->add_input("x");
  sum->add_input("y");
  (*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32);
  return graph_def;
}

tf2xla::Config SumConfig() {
  tf2xla::Config config;
  config.add_feed()->mutable_id()->set_node_name("x");
  config.add_feed()->mutable_id()->set_node_name("y");
  config.add_fetch()->mutable_id()->set_node_name("sum");
  return config;
}

TEST(ConvertGraphDefToXla, Sum) {
  GraphDef graph_def = SumGraph();
  tf2xla::Config config = SumConfig();

  xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
  xla::Computation computation;
  bool requires_runtime_context;
  TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation,
                                    &requires_runtime_context));
  ASSERT_FALSE(requires_runtime_context);

  // Set up arguments.
  auto x_literal = xla::Literal::CreateR0<int32>(10);
  auto y_literal = xla::Literal::CreateR0<int32>(32);
  auto x_global_or = client->TransferToServer(*x_literal);
  auto y_global_or = client->TransferToServer(*y_literal);
  TF_EXPECT_OK(x_global_or.status());
  TF_EXPECT_OK(y_global_or.status());
  std::unique_ptr<xla::GlobalData> x_global =
      std::move(x_global_or.ValueOrDie());
  std::unique_ptr<xla::GlobalData> y_global =
      std::move(y_global_or.ValueOrDie());

  // Execute and check result.
  auto result_or =
      client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
  TF_EXPECT_OK(result_or.status());
  std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
  EXPECT_EQ("42", result->ToString());
}

}  // namespace
}  // namespace tensorflow
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include <queue>
#include <set>
#include <unordered_map>

#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -29,21 +29,13 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace tfcompile {

namespace {

bool IsAlpha(char c) {
  return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}

bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }

Status ValidateTensorId(const TensorId& id) {
Status ValidateTensorId(const tf2xla::TensorId& id) {
  if (id.node_name().empty()) {
    return errors::InvalidArgument("TensorId node_name must be non-empty");
  }
@@ -53,10 +45,9 @@ Status ValidateTensorId(const TensorId& id) {
  return Status::OK();
}

Status ValidateFeedFetchName(const string& kind, const string& name,
                             std::set<string>* names) {
Status CheckNameDuplicates(const string& kind, const string& name,
                           std::set<string>* names) {
  if (!name.empty()) {
    TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name"));
    if (!names->insert(name).second) {
      return errors::InvalidArgument("duplicate ", kind, " name: ", name);
    }
@@ -80,42 +71,18 @@ Status CheckFeedFetchNameConflicts(const string& kind,

}  // namespace

Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
  if (ident.empty()) {
    return errors::InvalidArgument("empty identifier: ", msg);
  }
  // Require that the identifier starts with a nondigit, and is composed of
  // nondigits and digits, as specified in section [2.11 Identifiers] of the
  // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
  // defined as [0-9].
  //
  // Technically the standard also allows for `universal-character-name`, with a
  // table of allowed unicode ranges, as well as `other implementation-defined
  // characters`. We disallow those here to give better error messages, at the
  // expense of being more restrictive than the standard.
  if (ident[0] != '_' && !IsAlpha(ident[0])) {
    return errors::InvalidArgument("illegal leading char: ", msg);
  }
  for (size_t pos = 1; pos < ident.size(); ++pos) {
    if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
      return errors::InvalidArgument("illegal char: ", msg);
    }
  }
  return Status::OK();
}
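// For instance, "abc" and "_abc123" are accepted, while "0" is rejected with
// "illegal leading char" and "a." with "illegal char" (see the
// ValidateCppIdent test below).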

Status ValidateConfig(const Config& config) {
Status ValidateConfig(const tf2xla::Config& config) {
  std::set<string> names;
  for (const Feed& feed : config.feed()) {
  for (const tf2xla::Feed& feed : config.feed()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
    TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
  names.clear();
  for (const Fetch& fetch : config.fetch()) {
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
    TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
  if (config.feed().empty() || config.fetch().empty()) {
@@ -125,10 +92,10 @@ Status ValidateConfig(const Config& config) {
}

Status AddPlaceholdersForFeeds(
    const Config& config, const OpRegistryInterface* op_registry,
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
  struct PlaceholderInfo {
    const Feed* feed = nullptr;  // point to Feed in <config>.
    const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
    string placeholder_name;
    DataType data_type = DT_INVALID;
  };
@@ -137,9 +104,9 @@ Status AddPlaceholdersForFeeds(
  // when creating placeholders (genrules want deterministic output).
  std::map<string, PlaceholderInfo> placeholder_info;
  for (int i = 0; i < config.feed_size(); ++i) {
    const Feed* feed = &config.feed(i);
    const tf2xla::Feed* feed = &config.feed(i);
    const string name_port = TensorIdToString(feed->id());
    auto& info = placeholder_info[name_port];
    PlaceholderInfo& info = placeholder_info[name_port];
    info.feed = feed;
    info.placeholder_name = strings::StrCat(
        "aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
@@ -153,7 +120,7 @@ Status AddPlaceholdersForFeeds(
  }
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    PlaceholderInfo& info = it->second;
    const TensorId& feed_id = info.feed->id();
    const tf2xla::TensorId& feed_id = info.feed->id();

    // Find the existing node and determine data type.
    auto node_it = name_to_node.find(feed_id.node_name());
@@ -214,16 +181,16 @@ Status AddPlaceholdersForFeeds(
  return Status::OK();
}

Status PruneGraphDefInto(const Config& config, const GraphDef& in,
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // Tensors needed for feeding.
  std::set<std::pair<string, int>> feed_tensors;
  for (const auto& feed_config : config.feed()) {
    feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
                                       feed_config.id().output_index()));
  for (const tf2xla::Feed& feed : config.feed()) {
    feed_tensors.insert(
        std::make_pair(feed.id().node_name(), feed.id().output_index()));
  }

  // Maps node name to reachability.
@@ -279,9 +246,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
  return Status::OK();
}

string TensorIdToString(const TensorId& id) {
string TensorIdToString(const tf2xla::TensorId& id) {
  return strings::StrCat(id.node_name(), ":", id.output_index());
}

}  // namespace tfcompile
}  // namespace tensorflow
@@ -13,26 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_

#include <unordered_map>

#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"

namespace tensorflow {
namespace tfcompile {

// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
// appended to error messages.
Status ValidateCppIdent(StringPiece ident, StringPiece msg);

// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const Config& config);
Status ValidateConfig(const tf2xla::Config& config);

// Modifies <graph_def> to include placeholders for each fed tensor, and
// update references to the fed tensors to refer to the placeholders.
@@ -40,18 +34,17 @@ Status ValidateConfig(const Config& config);
// (except where their input edges are modified by the replacement of other
// feeds).
Status AddPlaceholdersForFeeds(
    const Config& config, const OpRegistryInterface* op_registry,
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);

// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out);

// Returns node:port for the given <id>.
string TensorIdToString(const TensorId& id);
string TensorIdToString(const tf2xla::TensorId& id);

}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace tfcompile {
namespace {

void ExpectErrorContains(const Status& status, StringPiece str) {
@@ -32,45 +31,16 @@ void ExpectErrorContains(const Status& status, StringPiece str) {
      << "expected error: " << status.error_message() << " to contain: " << str;
}

TEST(ValidateCppIdent, Simple) {
  TF_EXPECT_OK(ValidateCppIdent("a", ""));
  TF_EXPECT_OK(ValidateCppIdent("abc", ""));
  TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
  TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
  // Make sure we didn't skip a valid letter or digit
  string ident;
  for (char c = 'a'; c <= 'z'; c++) {
    ident.append(1, c);
  }
  for (char c = 'A'; c <= 'Z'; c++) {
    ident.append(1, c);
  }
  for (char c = '0'; c <= '9'; c++) {
    ident.append(1, c);
  }
  ident += "_";
  TF_EXPECT_OK(ValidateCppIdent(ident, ""));

  ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
  ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
  ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
}

TEST(ValidateConfig, Good) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  feed->mutable_id()->set_output_index(123);
  feed->set_name("foo_debug");
  feed = config.add_feed();
  feed->mutable_id()->set_node_name("bar");
  feed->mutable_id()->set_output_index(0);
  Fetch* fetch = config.add_fetch();
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("baz");
  fetch->mutable_id()->set_output_index(456);
  fetch->set_name("baz_debug");
@@ -81,62 +51,62 @@ TEST(ValidateConfig, Good) {
}

TEST(ValidateConfig, BadEmpty) {
  Config config;
  tf2xla::Config config;
  ExpectErrorContains(ValidateConfig(config),
                      "feeds and fetches must be specified");
}

TEST(ValidateConfig, BadNoFeed) {
  Config config;
  Fetch* fetch = config.add_fetch();
  tf2xla::Config config;
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("foo");
  ExpectErrorContains(ValidateConfig(config),
                      "feeds and fetches must be specified");
}

TEST(ValidateConfig, BadNoFetch) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  ExpectErrorContains(ValidateConfig(config),
                      "feeds and fetches must be specified");
}

TEST(ValidateConfig, BadFeedNodeName) {
  Config config;
  tf2xla::Config config;
  config.add_feed();
  ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}

TEST(ValidateConfig, BadFeedOutputIndex) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  feed->mutable_id()->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}

TEST(ValidateConfig, BadFetchNodeName) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  config.add_fetch();
  ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}

TEST(ValidateConfig, BadFetchOutputIndex) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  Fetch* fetch = config.add_fetch();
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("bar");
  fetch->mutable_id()->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}

TEST(ValidateConfig, DuplicateFeedName) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  feed->set_name("dup");
  feed = config.add_feed();
@@ -146,10 +116,10 @@ TEST(ValidateConfig, DuplicateFeedName) {
}

TEST(ValidateConfig, DuplicateFetchName) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  Fetch* fetch = config.add_fetch();
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("bar");
  fetch->set_name("dup");
  fetch = config.add_fetch();
@@ -159,8 +129,8 @@ TEST(ValidateConfig, DuplicateFetchName) {
}

TEST(ValidateConfig, ConflictingFeedName) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  feed->set_name("conflict");
  feed = config.add_feed();
@@ -170,10 +140,10 @@ TEST(ValidateConfig, ConflictingFeedName) {
}

TEST(ValidateConfig, ConflictingFetchName) {
  Config config;
  Feed* feed = config.add_feed();
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  Fetch* fetch = config.add_fetch();
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("bar");
  fetch->set_name("conflict");
  fetch = config.add_fetch();
@@ -182,8 +152,8 @@ TEST(ValidateConfig, ConflictingFetchName) {
  ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
}

static Config FetchesConfig(std::vector<string> fetches) {
  Config config;
static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
  tf2xla::Config config;
  for (const auto& fetch_node_name : fetches) {
    auto* fetch = config.add_fetch();
    fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
@@ -242,5 +212,4 @@ TEST(PruneGraphDefInto, Basic) {
}

}  // namespace
}  // namespace tfcompile
}  // namespace tensorflow
@@ -1703,7 +1703,6 @@ StatusOr<Computation> ComputationBuilder::Build() {
}

void ComputationBuilder::AddOpMetadata(OpRequest* request) const {
  tensorflow::mutex_lock lock(mutex_);
  *request->mutable_metadata() = metadata_;
}

@@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

@@ -57,10 +56,10 @@ class ComputationBuilder {
  ~ComputationBuilder();

  // Returns the client the builder was initialized with.
  Client* client() { return client_; }
  Client* client() const { return client_; }

  // Returns the computation name.
  const string& name() { return name_; }
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
@@ -69,13 +68,11 @@
  // instructions generated via this Computation Builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(const OpMetadata& metadata) {
    tensorflow::mutex_lock lock(mutex_);
    metadata_ = metadata;
  }

  // Clears the HloMetadata state.
  void ClearOpMetadata() {
    tensorflow::mutex_lock lock(mutex_);
    metadata_.Clear();
  }
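  // A sketch of the modal usage this enables (assumes OpMetadata exposes a
  // set_op_name setter; not shown in this change):
  //
  //   OpMetadata metadata;
  //   metadata.set_op_name("my_op");
  //   builder.SetOpMetadata(metadata);
  //   auto x = builder.ConstantR0<float>(1.0f);  // tagged with metadata
  //   builder.ClearOpMetadata();                 // later ops are untagged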

@@ -826,15 +823,12 @@
  Client* client_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_{false};

  // Mutex to guard against concurrent access to metadata_.
  mutable tensorflow::mutex mutex_;
  bool die_immediately_on_error_ = false;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_ GUARDED_BY(mutex_);
  OpMetadata metadata_;

  TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};

@@ -180,15 +180,18 @@ cc_library(

cc_library(
    name = "ir_emitter",
    srcs = ["ir_emitter.cc"],
    srcs = [
        "elemental_ir_emitter.cc",
        "ir_emitter.cc",
    ],
    hdrs = [
        "elemental_ir_emitter.h",
        "ir_emitter.h",
    ],
    deps = [
        ":cpu_options",
        ":cpu_runtime",
        ":dot_op_emitter",
        ":elemental_ir_emitter",
        ":ir_emission_utils",
        ":simple_orc_jit",
        "//tensorflow/compiler/xla:shape_util",
@@ -525,22 +528,6 @@ cc_library(
    ],
)

cc_library(
    name = "elemental_ir_emitter",
    srcs = ["elemental_ir_emitter.cc"],
    hdrs = ["elemental_ir_emitter.h"],
    deps = [
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:elemental_ir_emitter",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
        "@llvm//:core",
    ],
)

cc_library(
    name = "ir_emission_utils",
    srcs = ["ir_emission_utils.cc"],

@@ -50,14 +50,6 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
    return false;
  }

  // Producer or consumer cannot be Map. Maps are technically elementwise but
  // of a slightly different form (call instead of a computation). These are not
  // yet supported in the CPU backend.
  if (producer->opcode() == HloOpcode::kMap ||
      consumer->opcode() == HloOpcode::kMap) {
    return false;
  }

  // Cost condition: not fuse (simple, expensive producers) and (consumers who
  // reuse operand elements).
  if (producer->opcode() != HloOpcode::kFusion &&
@@ -209,6 +209,31 @@ class OpcodeFusionTest : public InstructionFusionTest {
        std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
        expected_opcodes);
  }

  HloComputation* CreateAdderToOne(HloModule* module) {
    HloComputation::Builder builder(TestName());
    HloInstruction* arg0 =
        builder.AddInstruction(HloInstruction::CreateParameter(
            0, ShapeUtil::MakeShape(F32, {}), "arg0"));
    HloInstruction* one = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    builder.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
    return module->AddEmbeddedComputation(builder.Build());
  }

  HloComputation* CreateMax(HloModule* module) {
    HloComputation::Builder builder(TestName());
    HloInstruction* arg0 =
        builder.AddInstruction(HloInstruction::CreateParameter(
            0, ShapeUtil::MakeShape(F32, {}), "arg0"));
    HloInstruction* arg1 =
        builder.AddInstruction(HloInstruction::CreateParameter(
            1, ShapeUtil::MakeShape(F32, {}), "arg1"));
    builder.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1));
    return module->AddEmbeddedComputation(builder.Build());
  }
};

TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) {
@@ -402,6 +427,49 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
                                    HloOpcode::kParameter});
}

TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
  auto module = CreateNewModule();

  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));

  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
  builder.AddInstruction(HloInstruction::CreateMap(
      shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{}));

  module->AddEntryComputation(builder.Build());

  RunFusionAndCheckOpcodesWereFused(
      module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap});
}

TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
  auto module = CreateNewModule();

  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param"));

  HloInstruction* exp0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
  HloInstruction* exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1));

  builder.AddInstruction(HloInstruction::CreateMap(
      shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{}));

  module->AddEntryComputation(builder.Build());

  RunFusionAndCheckOpcodesWereFused(
      module.get(), {HloOpcode::kParameter, HloOpcode::kParameter,
                     HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap});
}
}  // namespace
}  // namespace cpu
}  // namespace xla
@@ -64,5 +64,25 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
  }
}

llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const HloToElementGeneratorMap& operand_to_generator) const {
  if (hlo->opcode() == HloOpcode::kMap) {
    return [this, hlo, &operand_to_generator](
        const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
      std::vector<llvm::Value*> operands;
      for (int i = 0; i < hlo->operand_count(); i++) {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(i))(
                                ElementwiseSourceIndex(index, *hlo, 0)));
        operands.push_back(operand_value);
      }
      return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
                                         hlo->to_apply(), operands,
                                         llvm_ir::IrName(hlo));
    };
  }
  return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
}
}  // namespace cpu
}  // namespace xla
@@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -29,12 +30,19 @@ namespace cpu {
class CpuElementalIrEmitter : public ElementalIrEmitter {
 public:
  CpuElementalIrEmitter(const HloModuleConfig& module_config,
                        llvm::IRBuilder<>* ir_builder, llvm::Module* module)
      : ElementalIrEmitter(module_config, module, ir_builder) {}
                        IrEmitter* ir_emitter, llvm::Module* module)
      : ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()),
        ir_emitter_(ir_emitter) {}

  llvm_ir::ElementGenerator MakeElementGenerator(
      const HloInstruction* hlo,
      const HloToElementGeneratorMap& operand_to_generator) const override;

 protected:
  StatusOr<llvm::Value*> EmitFloatUnaryOp(
      const HloInstruction* op, llvm::Value* operand_value) const override;

  IrEmitter* ir_emitter_;
};

}  // namespace cpu
@@ -136,6 +136,10 @@ DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
  const int64 kReductionDimensionThresholdBytes = 8 * 1024;
  const bool single_threaded_eigen =
      !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();

  // This is the point at which it is better to call into Eigen and shard the
  // dot across multiple worker threads. This is a rough estimate by running
  // a matmul benchmark on my local machine, and it can be tuned further.
  const int64 kMaxSingleThreadedFlops = 16 * 1024;

  const int64 M = result_shape.dimensions(0);
@@ -2354,8 +2354,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  for (HloInstruction* operand : fusion->operands()) {
    parameter_arrays.push_back(GetIrArrayForOp(operand));
  }
  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
                                          module_);
  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
  TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));

@@ -2737,14 +2736,10 @@ llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) {
    }

    prof_counter_idx = it->second;
    uintptr_t hlo_address = reinterpret_cast<uintptr_t>(hlo);
    counter_name = tensorflow::strings::StrCat(
        "prof_counter_0x",
        tensorflow::strings::Hex(
            hlo_address, tensorflow::strings::PadSpec(sizeof(hlo_address))));
    counter_name = IrName("prof_counter", hlo->name());
  } else {
    prof_counter_idx = hlo_to_profile_idx_->size();
    counter_name = "prof_counter_computation";
    counter_name = "prof_counter.computation";
  }
  return ir_builder_.CreateGEP(GetProfileCountersArgument(),
                               ir_builder_.getInt64(prof_counter_idx),
@@ -3180,12 +3175,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
      return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_);
    };
  }
  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
                                          module_);
  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  return EmitTargetElementLoop(
      hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}

StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
    PrimitiveType return_type, HloComputation* computation,
    const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
  llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
  std::vector<llvm::Value*> argument_addrs;
  for (auto argument : arguments) {
    llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        argument->getType(), "arg_addr", &ir_builder_);
    ir_builder_.CreateStore(argument, argument_addr);
    argument_addrs.push_back(argument_addr);
  }
  return EmitElementFunctionCall(llvm_function,
                                 ShapeUtil::MakeShape(return_type, {}),
                                 argument_addrs, name);
}

unsigned TargetMachineFeatures::largest_register_size_in_bytes(
    llvm::Function* function) {
  auto itr = largest_register_size_in_bytes_.find(function);
@@ -133,6 +133,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
                         bool is_top_level_computation,
                         std::vector<const HloInstruction*>* instruction_order);

  llvm::IRBuilder<>* ir_builder() { return &ir_builder_; }

  // Emits a call to `computation` with scalar arguments `arguments`.
  StatusOr<llvm::Value*> EmitScalarCall(
      PrimitiveType return_type, HloComputation* computation,
      const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);

 protected:
  //
  // The following methods implement the DfsHloVisitor interface.
@@ -76,10 +76,11 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
  // Since CUDA 9.0, all GPU versions are included in a single file
  const char* unified_libdevice_filename = "libdevice.10.bc";
  std::vector<string> unified_libdevice_files;
  tensorflow::Env::Default()->GetMatchingPaths(
  const tensorflow::Status status =
      tensorflow::Env::Default()->GetMatchingPaths(
          tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename),
          &unified_libdevice_files);
  if( unified_libdevice_files.size() == 1 ) {
  if (status.ok() && unified_libdevice_files.size() == 1) {
    return unified_libdevice_filename;
  }
  // There are only four libdevice files: compute_{20,30,35,50}. Each GPU
@@ -77,7 +77,7 @@ class HloOrdering {
  // Precondition: 'a' and 'b' are in the same computation.
  //
  // Derived classes should implement this method for determining order of
  // instructions in the same comptuation. ExecutesBefore() analyzes the
  // instructions in the same computation. ExecutesBefore() analyzes the
  // callgraph and uses this method to determine ordering of instructions in
  // different computations.
  virtual bool ExecutesBeforeInSameComputation(
@@ -50,7 +50,7 @@ class WhileTest : public ClientLibraryTestBase {};
// while (result < 5) {
//   result = result + 1;
// }
TEST_F(WhileTest, WhileWithScalarResult) {
TEST_F(WhileTest, WhileWithScalarS32Result) {
  auto result_shape = ShapeUtil::MakeShape(S32, {});

  // Create a computation for the condition: repeat for 5 iterations.
@@ -81,6 +81,43 @@ TEST_F(WhileTest, WhileWithScalarResult) {
  ComputeAndCompareR0<int32>(&builder, 5, {});
}

// Tests a while node when the result type T is S64.
//
// int64 result = 0;
// while (result < 5) {
//   result = result + 1;
// }
TEST_F(WhileTest, WhileWithScalarS64Result) {
  auto result_shape = ShapeUtil::MakeShape(S64, {});

  // Create a computation for the condition: repeat for 5 iterations.
  Computation condition;
  {
    ComputationBuilder builder(client_, "condition");
    auto prev = builder.Parameter(0, result_shape, "prev");
    builder.Gt(builder.ConstantR0<int64>(5), prev);
    condition = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the body: add 1 to the result variable.
  Computation body;
  {
    ComputationBuilder builder(client_, "body");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto input = builder.ConstantR0<int64>(1);
    auto result = builder.Add(input, prev);
    body = builder.Build().ConsumeValueOrDie();
  }

  // Create a While node with computations for the condition and the body.
  ComputationBuilder builder(client_, TestName());
  auto init = builder.ConstantR0<int64>(0);
  auto result = builder.While(condition, body, init);
  auto shape = builder.GetShape(result).ConsumeValueOrDie();

  ComputeAndCompareR0<int64>(&builder, 5, {});
}

TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
  auto result_shape = ShapeUtil::MakeShape(S32, {});
  auto orig_shape = ShapeUtil::MakeShape(S32, {2});
@@ -33,6 +33,7 @@ py_library(
    "//tensorflow/contrib/graph_editor:graph_editor_py",
    "//tensorflow/contrib/grid_rnn:grid_rnn_py",
    "//tensorflow/contrib/hooks",
    "//tensorflow/contrib/image:distort_image_py",
    "//tensorflow/contrib/image:image_py",
    "//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
    "//tensorflow/contrib/imperative",
@@ -20,10 +20,10 @@ import android.os.Build.VERSION;
import android.os.Trace;
import android.text.TextUtils;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
@@ -79,24 +79,32 @@ public class TensorFlowInferenceInterface {
      throw new RuntimeException("Failed to load model from '" + model + "'", e);
    }
  }

  try {
    if (VERSION.SDK_INT >= 18) {
      Trace.beginSection("initializeTensorFlow");
      Trace.beginSection("readGraphDef");
    }

    // TODO(ashankar): Can we somehow mmap the contents instead of copying them?
    byte[] graphDef = new byte[is.available()];
    final int numBytesRead = is.read(graphDef);
    if (numBytesRead != graphDef.length) {
      throw new IOException(
          "read error: read only "
              + numBytesRead
              + " of the graph, expected to read "
              + graphDef.length);
    }

    if (VERSION.SDK_INT >= 18) {
      Trace.endSection(); // readGraphDef.
    }

    loadGraph(graphDef, g);
    is.close();
    Log.i(TAG, "Successfully loaded model from '" + model + "'");

    if (VERSION.SDK_INT >= 18) {
      Trace.endSection(); // initializeTensorFlow.
    }
@@ -121,13 +129,13 @@ public class TensorFlowInferenceInterface {
    this.g = new Graph();
    this.sess = new Session(g);
    this.runner = sess.runner();

    try {
      if (VERSION.SDK_INT >= 18) {
        Trace.beginSection("initializeTensorFlow");
        Trace.beginSection("readGraphDef");
      }

      int baosInitSize = is.available() > 16384 ? is.available() : 16384;
      ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
      int numBytesRead;
@@ -143,7 +151,7 @@ public class TensorFlowInferenceInterface {

      loadGraph(graphDef, g);
      Log.i(TAG, "Successfully loaded model from the input stream");

      if (VERSION.SDK_INT >= 18) {
        Trace.endSection(); // initializeTensorFlow.
      }
@@ -309,8 +317,8 @@ public class TensorFlowInferenceInterface {

  /**
   * Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
   * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of
   * bytes, not a Java {@code String} (which is a sequence of characters).
   * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
   * a Java {@code String} (which is a sequence of characters).
   */
  public void feedString(String inputName, byte[] src) {
    addFeed(inputName, Tensor.create(src));
@@ -318,9 +326,8 @@ public class TensorFlowInferenceInterface {

  /**
   * Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
   * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string"
   * is an arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of
   * characters).
   * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
   * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
   */
  public void feedString(String inputName, byte[][] src) {
    addFeed(inputName, Tensor.create(src));
@@ -151,7 +151,7 @@ def convert_to_universal_format(dtec, sorted_feature_names,
          generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
      inequality_test.threshold.float_value = split.threshold
    elif node_type == "sparse_float_binary_split_default_right":
      split = gtflow_node.sparse_float_binary_split_default_right
      split = gtflow_node.sparse_float_binary_split_default_right.split
      node.default_direction = (
          generic_tree_model_pb2.BinaryNode.RIGHT)
      feature_id = split.feature_column + num_dense
@@ -329,7 +329,6 @@ add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests")
add_python_module("tensorflow/contrib/cudnn_rnn/python/ops")
add_python_module("tensorflow/contrib/data")
add_python_module("tensorflow/contrib/data/python")
add_python_module("tensorflow/contrib/data/python/framework")
add_python_module("tensorflow/contrib/data/python/kernel_tests")
add_python_module("tensorflow/contrib/data/python/ops")
add_python_module("tensorflow/contrib/data/python/util")
@@ -362,6 +361,8 @@ add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/gan/python")
add_python_module("tensorflow/contrib/gan/python/eval")
add_python_module("tensorflow/contrib/gan/python/eval/python")
add_python_module("tensorflow/contrib/gan/python/features")
add_python_module("tensorflow/contrib/gan/python/features/python")
add_python_module("tensorflow/contrib/gan/python/losses")
@@ -20,6 +20,7 @@
@@FixedLengthRecordDataset
@@TextLineDataset

@@batch_and_drop_remainder
@@read_batch_features
@@rejection_resample
@@group_by_window
@@ -32,6 +33,7 @@ from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.dataset_ops import batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window
@@ -1,48 +0,0 @@
package(default_visibility = ["//tensorflow:internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "py_test")

py_library(
    name = "function",
    srcs = ["function.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:function",
        "//tensorflow/python:graph_to_function_def",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
    ],
)

py_test(
    name = "function_test",
    size = "medium",
    srcs = ["function_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":function",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
@ -1,275 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""An experimental fork of the Python TensorFlow-function library.
|
||||
|
||||
NOTE: functions are currently experimental and subject to change!
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import graph_to_function_def
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
# NOTE(mrry): This is an experimental extension of a core class that wasn't
|
||||
# designed to be extended, so we disable protected access checks for the
|
||||
# whole file.
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
class _ExperimentalFuncGraph(function._FuncGraph):
|
||||
"""A helper for construction a function (supporting capture-by-value).
|
||||
|
||||
_ExperimentalFuncGraph overrides ops.Graph's create_op() so that we can keep
|
||||
track of every inputs into every op created inside the function. If
|
||||
any input is from other graphs, we keep track of it in self.capture
|
||||
and substitute the input with a place holder.
|
||||
|
||||
Each captured input's corresponding place holder is converted into a
|
||||
function argument and the caller passes in the captured tensor.
|
||||
"""
  def __init__(self, capture_by_value, *args, **kwargs):
    super(_ExperimentalFuncGraph, self).__init__(*args, **kwargs)
    self._capture_by_value = capture_by_value
    self._building_function = True
    self._outer_graph = ops.get_default_graph()
    self._vscope = vs.get_variable_scope()
    self._old_custom_getter = self._vscope.custom_getter
    self._captured = {}
    self.extra_inputs = []
    self.extra_args = []
    self.extra_vars = []

  def create_op(self, op_type, inputs, data_types, **kwargs):
    for i, x in enumerate(inputs):
      if x.graph is not self:
        # Referring to a tensor from another graph.
        if x in self._captured:
          # Captured already.
          inputs[i] = self._captured[x]
        elif self._capture_by_value:
          inputs[i] = self._add_tensor_and_parents(x)
        else:
          # Substitute with a placeholder.
          self.extra_inputs.append(x)
          ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
          # pylint: disable=protected-access
          ph._handle_data = x._handle_data
          # pylint: enable=protected-access
          inputs[i] = ph
          self._captured[x] = ph
          self.extra_args.append(ph)
    return super(_ExperimentalFuncGraph, self).create_op(op_type, inputs,
                                                         data_types, **kwargs)

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    op_def = graph_to_function_def._get_op_def(op)
    if op_def.is_stateful:
      raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
                       "by value." % (op.name, op.type))
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
                       "by value." % (op.name, op.type))

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self.create_op(op.type, captured_inputs,
                                 [o.dtype for o in op.outputs],
                                 name=op.name, attrs=op.node_def.attr,
                                 op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t] = captured_t

    return captured_op


class _ExperimentalDefinedFunction(function._DefinedFunction):
  """Overrides _DefinedFunction with support for capture-by-value."""

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               **kwargs):
    """Creates an _ExperimentalDefinedFunction.

    Args:
      func: A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case a name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.
    """
    super(_ExperimentalDefinedFunction, self).__init__(
        func, argnames, input_types, func_name, grad_func, python_grad_func,
        out_names, shape_func, **kwargs)
    self._capture_by_value = capture_by_value

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet."""
    with context.graph_mode():
      self._create_definition_if_needed_impl()

  def _create_definition_if_needed_impl(self):
    """You're looking for _create_definition_if_needed(), not this."""

    if self._definition is not None:
      return

    # Create the func_def object.
    temp_graph = _ExperimentalFuncGraph(
        capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
      # List of placeholders for the function_def.
      inputs = []
      for (argname, argtype) in self._args:
        argholder = array_ops.placeholder(argtype, name=argname)
        inputs.append(argholder)
      # Call func and gather the output tensors.
      with vs.variable_scope("", custom_getter=temp_graph.getvar):
        outputs = self._func(*inputs)
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any([_ is None for _ in outputs]):
        raise ValueError("Function can not return None.")
      # Ensures each output is a Tensor.
      outputs = [ops.convert_to_tensor(_) for _ in outputs]
    self._extra_inputs = temp_graph.extra_inputs
    inputs.extend(temp_graph.extra_args)
    self._sub_functions = temp_graph._functions

    # Build the FunctionDef
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph, temp_graph.get_operations(), inputs, outputs,
        out_names=self._out_names)

    # Extra kwargs are treated as attrs on the function def.
    sig_pre_func_name = self._func_name or function._get_func_name(self._func)
    kwargs_attr = function._parse_kwargs_as_attrs(
        sig_pre_func_name, **self._extra_kwargs)
    for k in kwargs_attr:
      self._definition.attr[k].CopyFrom(kwargs_attr[k])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        self._definition.signature.input_arg,
        self._definition.signature.output_arg,
        self._definition.node_def)

    # Finally, we decide the function name to use. If not specified,
    # make up something which is almost certainly unique (but deterministic).
    if not self._func_name:
      self._func_name = "_".join([function._get_func_name(self._func),
                                  self._hash_str])
    self._definition.signature.name = self._func_name
    if self._func.__doc__:
      self._definition.signature.description = self._func.__doc__


class Defun(function.Defun):
  """Experimental version of Defun supporting capture-by-value."""

  def __init__(self, *input_types, **kwargs):
    """Create an experimental `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`
      **kwargs: Optional keyword arguments (see `function.Defun`) plus:
        capture_by_value - Boolean (defaults to False). If True, captured
          values will be copied into the function body.
    """
    super(Defun, self).__init__(*input_types, **kwargs)

  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError("func %s must be callable" % func)

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError("Functions with argument defaults or keyword "
                       "arguments are not supported.")

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of specified input types does not match the number "
            "of arguments that the function accepts.")
      return _ExperimentalDefinedFunction(
          func, argnames, self._input_types, self._func_name, self._grad_func,
          self._python_grad_func, out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _ExperimentalDefinedFunction(
          func, [], [], self._func_name, self._grad_func,
          self._python_grad_func, out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return function._OverloadedFunction(
        func, argnames, self._func_name, self._grad_func,
        self._python_grad_func, out_names=self._out_names, **self._extra_kwargs)
tensorflow/contrib/data/python/framework/function_test.py (deleted file, 59 lines)
@@ -1,59 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for experimental capture-by-value feature in TF functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.data.python.framework import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class FunctionTest(test.TestCase):

  def testCaptureByValue(self):
    g = ops.Graph()
    with g.as_default():
      w = constant_op.constant([[1.0]])
      b = constant_op.constant([2.0])

      # Foo() captures w and b.
      @function.Defun(dtypes.float32, capture_by_value=True)
      def Foo(x):

        # Plus() captures b.
        @function.Defun(dtypes.float32, capture_by_value=True)
        def Plus(y):
          return y + b

        self.assertEqual(0, len(Plus.captured_inputs))

        return Plus(math_ops.matmul(w, x))

      y = Foo(constant_op.constant([[10.]]))

      self.assertEqual(0, len(Foo.captured_inputs))

    with self.test_session(graph=g):
      self.assertAllEqual(y.eval(), [[12.0]])


if __name__ == "__main__":
  test.main()
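For contrast with the test above, here is a minimal sketch of the default by-reference behavior under the same TF 1.x API. The `PlusRef` name is illustrative and not from the diff; this assumes the experimental module as it existed before the deletion above.

```python
# Hypothetical contrast to testCaptureByValue (TF 1.x graph mode): with the
# default capture_by_value=False, a captured tensor becomes a hidden extra
# input rather than a constant copied into the function body.
import tensorflow as tf
from tensorflow.contrib.data.python.framework import function

g = tf.Graph()
with g.as_default():
  b = tf.constant([2.0])

  @function.Defun(tf.float32)  # capture_by_value defaults to False
  def PlusRef(y):
    return y + b  # `b` is captured by reference, not copied

  y = PlusRef(tf.constant([10.0]))
  # By-reference capture surfaces `b` as a captured input.
  assert len(PlusRef.captured_inputs) == 1
```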
@@ -264,6 +264,7 @@ py_test(
    srcs = ["resample_test.py"],
    shard_count = 2,
    srcs_version = "PY2AND3",
    tags = ["noasan"],
    deps = [
        "//tensorflow/contrib/data/python/ops:dataset_ops",
        "//tensorflow/python:client_testlib",
@@ -333,6 +333,55 @@ class BatchDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(op)

  def testBatchAndDropRemainder(self):
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    iterator = (dataset_ops.Dataset.from_tensor_slices(components)
                .apply(dataset_ops.batch_and_drop_remainder(batch_size))
                .make_initializable_iterator())

    next_element = iterator.get_next()

    with self.test_session() as sess:
      for test_batch_size in [1, 3, 7, 10]:
        sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
        num_batches = 7 // test_batch_size
        for i in range(num_batches):
          result = sess.run(next_element)
          for component, result_component in zip(components, result):
            for j in range(test_batch_size):
              self.assertAllEqual(component[(i * test_batch_size + j)],
                                  result_component[j])
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)

  def testBatchAndDropRemainderShapeInference(self):
    components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder(
        dtypes.int32, shape=[None]), array_ops.placeholder(
            dtypes.int32, shape=[20, 30])))

    # Test with a statically known batch size.
    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
               .apply(dataset_ops.batch_and_drop_remainder(128)))

    self.assertIs(None, dataset.output_shapes[0].ndims)
    self.assertEqual([128], dataset.output_shapes[1][0].as_list())
    self.assertEqual([128, 30], dataset.output_shapes[1][1].as_list())

    # Test with a dynamic batch size: the static shape will be unknown, because
    # `batch_size` is a placeholder.
    batch_size = array_ops.placeholder(dtypes.int64)
    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
               .apply(dataset_ops.batch_and_drop_remainder(batch_size)))

    self.assertIs(None, dataset.output_shapes[0].ndims)
    self.assertEqual([None], dataset.output_shapes[1][0].as_list())
    self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())


if __name__ == "__main__":
  test.main()
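The remainder arithmetic the test relies on can be checked in plain Python, with no TensorFlow involved. This sketch just restates the expectation that a 7-element dataset yields `7 // batch_size` full batches:

```python
# Plain-Python restatement of the drop-remainder expectation above.
def elements_kept(num_elements, batch_size):
  num_batches = num_elements // batch_size  # full batches only
  return num_batches * batch_size           # the trailing remainder is dropped

assert [elements_kept(7, b) for b in [1, 3, 7, 10]] == [7, 6, 7, 0]
```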
@@ -22,13 +22,17 @@ import threading

import numpy as np

from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.util import nest
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest


class DatasetConstructorTest(test.TestCase):
@@ -475,6 +479,75 @@ class DatasetConstructorTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSplitPipelineFailsWithPlacementError(self):
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:

      dataset = dataset_ops.Dataset.from_tensors(0)

      # Define a pipeline that attempts to use variables on two
      # different devices.
      #
      # Initialize the variables before creating the iterator, to avoid the
      # placement algorithm overriding the DT_RESOURCE colocation constraints.
      with ops.device("/cpu:0"):
        var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
        dataset = dataset.map(lambda x: x + var_0.read_value())
        sess.run(var_0.initializer)

      with ops.device("/cpu:1"):
        var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
        dataset = dataset.map(lambda x: x + var_1.read_value())
        sess.run(var_1.initializer)

      iterator = dataset.make_initializable_iterator()

      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "Trying to access resource located in device"):
        sess.run(iterator.initializer)

  def testRestructureDataset(self):
    components = (array_ops.placeholder(dtypes.int32),
                  (array_ops.placeholder(dtypes.int32, shape=[None]),
                   array_ops.placeholder(dtypes.int32, shape=[20, 30])))
    dataset = dataset_ops.Dataset.from_tensors(components)

    i32 = dtypes.int32

    test_cases = [((i32, i32, i32), None),
                  (((i32, i32), i32), None),
                  ((i32, i32, i32), (None, None, None)),
                  ((i32, i32, i32), ([17], [17], [20, 30]))]

    for new_types, new_shape_lists in test_cases:
      # pylint: disable=protected-access
      new = dataset_ops._RestructuredDataset(
          dataset, new_types, new_shape_lists)
      # pylint: enable=protected-access
      self.assertEqual(new_types, new.output_types)
      if new_shape_lists is not None:
        for expected_shape_list, shape in zip(
            nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)):
          if expected_shape_list is None:
            self.assertIs(None, shape.ndims)
          else:
            self.assertEqual(expected_shape_list, shape.as_list())

    fail_cases = [((i32, dtypes.int64, i32), None),
                  ((i32, i32, i32, i32), None),
                  ((i32, i32, i32), ((None, None), None)),
                  ((i32, i32, i32), (None, None, None, None)),
                  ((i32, i32, i32), (None, [None], [21, 30]))]

    for new_types, new_shape_lists in fail_cases:
      with self.assertRaises(ValueError):
        # pylint: disable=protected-access
        new = dataset_ops._RestructuredDataset(
            dataset, new_types, new_shape_lists)
        # pylint: enable=protected-access


if __name__ == "__main__":
  test.main()
|
||||
|
@ -21,9 +21,11 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import device_setter
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@ -50,7 +52,7 @@ class ResampleTest(test.TestCase):
|
||||
seed=27))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
variable_init_op = variables.global_variables_initializer()
|
||||
variable_init_op = variables.local_variables_initializer()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(variable_init_op)
|
||||
@ -74,6 +76,22 @@ class ResampleTest(test.TestCase):
|
||||
returned_dist = class_counts / total_returned
|
||||
self.assertAllClose(target_dist, returned_dist, atol=1e-2)
|
||||
|
||||
def testVariableDevicePlacement(self):
|
||||
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
|
||||
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
|
||||
with ops.device(
|
||||
device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")):
|
||||
dataset = (dataset_ops.Dataset.from_tensor_slices(classes)
|
||||
.shuffle(200, seed=21)
|
||||
.map(lambda c: (c, string_ops.as_string(c))))
|
||||
dataset = dataset_ops.rejection_resample(
|
||||
dataset, target_dist=target_dist, initial_dist=None,
|
||||
class_func=lambda c, _: c, seed=27)
|
||||
|
||||
self.assertEqual(1, len(variables.local_variables()))
|
||||
self.assertEqual(b"",
|
||||
compat.as_bytes(variables.local_variables()[0].device))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
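As background for the resampling tests, the acceptance rule behind rejection resampling can be sketched in a few lines of NumPy. This is a generic illustration of the technique, not the implementation in dataset_ops.py; the `rejection_resample_indices` helper is hypothetical.

```python
# Generic rejection-resampling sketch in NumPy: accept each example with a
# probability proportional to target_dist/initial_dist for its class.
import numpy as np

def rejection_resample_indices(classes, initial_dist, target_dist, rng):
  ratios = np.asarray(target_dist, dtype=float) / np.asarray(initial_dist)
  accept_prob = ratios / ratios.max()  # the most-needed class is always kept
  keep = rng.rand(len(classes)) < accept_prob[classes]
  return np.nonzero(keep)[0]

rng = np.random.RandomState(27)
classes = rng.randint(5, size=20000)  # uniformly sampled, as in the test
kept = rejection_resample_indices(
    classes, [0.2] * 5, [0.9, 0.05, 0.05, 0.0, 0.0], rng)
# The empirical class distribution of classes[kept] approaches target_dist.
```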
@@ -9,7 +9,6 @@ py_library(
    srcs = ["dataset_ops.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/data/python/framework:function",
        "//tensorflow/contrib/data/python/util:nest",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
@@ -17,6 +16,7 @@ py_library(
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:function",
        "//tensorflow/python:logging_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:parsing_ops",
@@ -38,11 +38,11 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        ":dataset_ops",
        "//tensorflow/contrib/data/python/framework:function",
        "//tensorflow/contrib/data/python/util:nest",
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:function",
        "//tensorflow/python:platform",
    ],
)
@@ -23,10 +23,10 @@ import threading

import numpy as np

from tensorflow.contrib.data.python.framework import function
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -99,8 +99,9 @@ class Iterator(object):
        shared_name=shared_name,
        output_types=nest.flatten(dataset.output_types),
        output_shapes=nest.flatten(dataset.output_shapes))
    initializer = gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
                                                iterator_resource)
    with ops.colocate_with(iterator_resource):
      initializer = gen_dataset_ops.make_iterator(
          dataset.make_dataset_resource(), iterator_resource)
    return Iterator(iterator_resource, initializer, dataset.output_types,
                    dataset.output_shapes)

@@ -291,6 +292,7 @@ class Iterator(object):
      raise TypeError("Expected output shapes compatible with %r but got "
                      "dataset with output shapes %r." %
                      (self._output_shapes, dataset.output_shapes))
    with ops.colocate_with(self._iterator_resource):
      return gen_dataset_ops.make_iterator(
          dataset.make_dataset_resource(), self._iterator_resource, name=name)
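Both hunks adopt the same pattern: build the consuming op inside a colocation scope on the resource it reads. A minimal sketch of the idiom, assuming TF 1.x graph mode and generic ops rather than the iterator implementation itself:

```python
# Generic illustration of the colocation idiom used above (TF 1.x API).
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  v = tf.Variable(1.0, name="resource")
  # Ops created in this scope are pinned to wherever `v` is placed,
  # so the placer cannot split them across devices.
  with tf.colocate_with(v.op):
    doubled = v * 2.0
  assert b"loc:@resource" in doubled.op.colocation_groups()
```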
@@ -2404,12 +2406,16 @@ def rejection_resample(dataset,
  num_classes = (target_dist.shape[0].value or
                 array_ops.shape(target_dist)[0])
  smoothing_constant = 10
  num_examples_per_class_seen = resource_variable_ops.ResourceVariable(
      initial_value=array_ops.fill([num_classes],
                                   np.int64(smoothing_constant)),
      trainable=False,
      name="class_count",
      dtype=dtypes.int64)
  # Disable device functions and colocation constraints so that the variable
  # will be placed with the eventual DT_VARIANT dataset tensor.
  with ops.colocate_with(None, ignore_existing=True):
    num_examples_per_class_seen = resource_variable_ops.ResourceVariable(
        initial_value=array_ops.fill([num_classes],
                                     np.int64(smoothing_constant)),
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="local_class_count",
        dtype=dtypes.int64)

  def update_estimate_and_tile(c):
    return array_ops.tile(
@@ -2519,7 +2525,13 @@ def read_batch_features(file_pattern,
    dataset = reader(filenames, *reader_args)
  else:
    dataset = reader(filenames)
  dataset = dataset.repeat(num_epochs)
  if dataset.output_types == (dtypes.string, dtypes.string):
    dataset = dataset.map(lambda unused_k, v: v)
  elif dataset.output_types != dtypes.string:
    raise TypeError("`reader` must be a dataset of `tf.string` values, "
                    "or `(tf.string, tf.string)` key-value pairs.")
  if num_epochs != 1:
    dataset = dataset.repeat(num_epochs)
  if randomize_input:
    dataset = dataset.shuffle(capacity)
  dataset = dataset.batch(batch_size)
@@ -2729,3 +2741,137 @@ def group_by_window(dataset,

  assert window_size_func is not None
  return GroupByWindowDataset(dataset, key_func, reduce_func, window_size_func)


class _RestructuredDataset(Dataset):
  """An internal helper for changing the structure and shape of a dataset."""

  def __init__(self, dataset, output_types, output_shapes=None):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types`, modulo
      nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. If omitted, the shapes will be inherited from `dataset`.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not
        compatible with the structure of `dataset`.
    """
    super(_RestructuredDataset, self).__init__()
    self._dataset = dataset

    # Validate that the types are compatible.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    flat_original_types = nest.flatten(dataset.output_types)
    flat_new_types = nest.flatten(output_types)
    if flat_original_types != flat_new_types:
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have output "
          "types %r" % (dataset.output_types, output_types))

    self._output_types = output_types

    if output_shapes is None:
      # Inherit shapes from the original `dataset`.
      self._output_shapes = nest.pack_sequence_as(
          output_types, nest.flatten(dataset.output_shapes))
    else:
      # Validate that the shapes are compatible.
      nest.assert_same_structure(output_types, output_shapes)
      flat_original_shapes = nest.flatten(dataset.output_shapes)
      flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

      for original_shape, new_shape in zip(flat_original_shapes,
                                           flat_new_shapes):
        if not original_shape.is_compatible_with(new_shape):
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r"
              % (dataset.output_shapes, output_shapes))
      self._output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)

  def make_dataset_resource(self):
    return self._dataset.make_dataset_resource()

  @property
  def output_types(self):
    return self._output_types

  @property
  def output_shapes(self):
    return self._output_shapes


def batch_and_drop_remainder(batch_size):
  """A batching transformation that omits the final small batch (if present).

  Like @{tf.contrib.data.Dataset.batch}, this transformation combines
  consecutive elements of this dataset into batches. However, if the batch
  size does not evenly divide the input dataset size, this transformation will
  drop the final smaller element.

  The following example illustrates the difference between this
  transformation and `Dataset.batch()`:

  ```python
  dataset = tf.contrib.data.Dataset.range(200)
  batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))
  print(batched.output_shapes)  # ==> "(128,)" (the batch dimension is known)
  ```

  By contrast, `dataset.batch(128)` would yield a two-element dataset with
  shapes `(128,)` and `(72,)`, so the batch dimension would not be statically
  known.

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.contrib.data.Dataset.apply}
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    tensor_batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")

    batched = dataset.batch(tensor_batch_size)
    flattened = _RestructuredDataset(batched,
                                     tuple(nest.flatten(batched.output_types)))

    def _predicate(*xs):
      """Return `True` if this element is a full batch."""
      # Extract the dynamic batch size from the first component of the
      # flattened batched element.
      first_component = xs[0]
      first_component_batch_size = array_ops.shape(
          first_component, out_type=dtypes.int64)[0]

      return math_ops.equal(first_component_batch_size, tensor_batch_size)

    filtered = flattened.filter(_predicate)

    maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)

    def _set_first_dimension(shape):
      return shape.merge_with(
          tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))

    known_shapes = nest.map_structure(_set_first_dimension,
                                      batched.output_shapes)
    return _RestructuredDataset(filtered, batched.output_types, known_shapes)

  return _apply_fn
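The transformation above composes three primitives: batch, filter on the dynamic batch dimension, and a shape "cast" via `_RestructuredDataset`. A hedged sketch of the same filtering idea, written against the public tf.contrib.data API of this era rather than the internals:

```python
# Sketch of the batch-then-filter idea behind batch_and_drop_remainder.
import tensorflow as tf

batch_size = 128
dataset = tf.contrib.data.Dataset.range(200)
batched = dataset.batch(batch_size)
# Keep only elements whose dynamic leading dimension equals batch_size;
# the final 72-element batch fails this predicate and is dropped.
full_batches = batched.filter(
    lambda x: tf.equal(tf.shape(x, out_type=tf.int64)[0], batch_size))
```

What `batch_and_drop_remainder` adds on top of this sketch is the static-shape bookkeeping: because every surviving element is known to be a full batch, the leading dimension can be merged back into the output shapes.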
@@ -17,10 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.data.python.framework import function
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops


@@ -30,6 +30,7 @@ cuda_py_test(
        ":tfe",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:platform_test",
        "//tensorflow/python/eager:test",
    ],
)


@@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.eager.python import tfe
from tensorflow.python.platform import test
from tensorflow.python.eager import test


class TFETest(test.TestCase):
@@ -31,6 +31,14 @@ class TFETest(test.TestCase):
    devices = tfe.list_devices()
    self.assertEqual(len(devices) - 1, tfe.num_gpus())

  def testCallingEnableEagerExecutionMoreThanOnce(self):
    # Note that eager.test.main() has already invoked enable_eager_execution().
    with self.assertRaisesRegexp(
        ValueError,
        r"Do not call tfe\.%s more than once in the same process" %
        tfe.enable_eager_execution.__name__):
      tfe.enable_eager_execution()


if __name__ == "__main__":
  test.main()
@@ -25,11 +25,46 @@ py_library(
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":dnn",
        ":extenders",
        ":head",
    ],
)

py_library(
    name = "dnn",
    srcs = ["python/estimator/dnn.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:nn",
        "//tensorflow/python/estimator",
        "//tensorflow/python/estimator:dnn",
    ],
)

py_test(
    name = "dnn_test",
    size = "small",
    srcs = ["python/estimator/dnn_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":dnn",
        ":head",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:summary",
        "//tensorflow/python/estimator:dnn_testing_utils",
        "//tensorflow/python/estimator:export_export",
        "//tensorflow/python/estimator:numpy_io",
        "//tensorflow/python/estimator:prediction_keys",
        "//tensorflow/python/feature_column",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

py_library(
    name = "extenders",
    srcs = [
@@ -68,6 +103,37 @@ py_library(
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:metrics",
        "//tensorflow/python:summary",
        "//tensorflow/python/estimator:export_output",
        "//tensorflow/python/estimator:head",
        "//tensorflow/python/estimator:metric_keys",
        "//tensorflow/python/estimator:model_fn",
        "//tensorflow/python/estimator:prediction_keys",
        "//tensorflow/python/ops/losses",
    ],
)

py_test(
    name = "head_test",
    size = "small",
    srcs = ["python/estimator/head_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":head",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:string_ops",
        "//tensorflow/python:training",
        "//tensorflow/python/estimator:metric_keys",
        "//tensorflow/python/estimator:model_fn",
        "//tensorflow/python/estimator:prediction_keys",
        "//tensorflow/python/saved_model:signature_constants",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.extenders import *
from tensorflow.contrib.estimator.python.estimator.head import *

@@ -29,7 +30,9 @@ _allowed_symbols = [
    'add_metrics',
    'binary_classification_head',
    'multi_class_head',
    'multi_label_head',
    'regression_head',
    'DNNEstimator',
]

remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
tensorflow/contrib/estimator/python/estimator/dnn.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deep Neural Network estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import dnn as dnn_lib
from tensorflow.python.ops import nn


class DNNEstimator(estimator.Estimator):
  """An estimator for TensorFlow DNN models with user-specified head.

  Example:

  ```python
  sparse_feature_a = sparse_column_with_hash_bucket(...)
  sparse_feature_b = sparse_column_with_hash_bucket(...)

  sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a,
                                          ...)
  sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,
                                          ...)

  estimator = DNNEstimator(
      head=tf.contrib.estimator.multi_label_head(n_classes=3),
      feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
      hidden_units=[1024, 512, 256])

  # Or an estimator using the ProximalAdagradOptimizer optimizer with
  # regularization.
  estimator = DNNEstimator(
      head=tf.contrib.estimator.multi_label_head(n_classes=3),
      feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
      hidden_units=[1024, 512, 256],
      optimizer=tf.train.ProximalAdagradOptimizer(
          learning_rate=0.1,
          l1_regularization_strength=0.001
      ))

  # Input builders
  def input_fn_train():  # returns x, y
    pass
  estimator.train(input_fn=input_fn_train, steps=100)

  def input_fn_eval():  # returns x, y
    pass
  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
  def input_fn_predict():  # returns x, None
    pass
  predictions = estimator.predict(input_fn=input_fn_predict)
  ```

  Input of `train` and `evaluate` should have the following features,
  otherwise there will be a `KeyError`:

  * if `weight_column` is not `None`, a feature with
    `key=weight_column` whose value is a `Tensor`.
  * for each `column` in `feature_columns`:
    - if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
      whose `value` is a `SparseTensor`.
    - if `column` is a `_WeightedCategoricalColumn`, two features: the first
      with `key` the id column name, the second with `key` the weight column
      name. Both features' `value` must be a `SparseTensor`.
    - if `column` is a `_DenseColumn`, a feature with `key=column.name`
      whose `value` is a `Tensor`.

  Loss and predicted output are determined by the specified head.
  """

  def __init__(self,
               head,
               hidden_units,
               feature_columns,
               model_dir=None,
               optimizer='Adagrad',
               activation_fn=nn.relu,
               dropout=None,
               input_layer_partitioner=None,
               config=None):
    """Initializes a `DNNEstimator` instance.

    Args:
      head: A `_Head` instance constructed with a method such as
        `tf.contrib.estimator.multi_label_head`.
      hidden_units: Iterable of numbers of hidden units per layer. All layers
        are fully connected. Ex. `[64, 32]` means the first layer has 64 nodes
        and the second one has 32.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `_FeatureColumn`.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      optimizer: An instance of `tf.Optimizer` used to train the model.
        Defaults to the Adagrad optimizer.
      activation_fn: Activation function applied to each layer. If `None`,
        will use `tf.nn.relu`.
      dropout: When not `None`, the probability we will drop out a given
        coordinate.
      input_layer_partitioner: Optional. Partitioner for input layer. Defaults
        to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
      config: `RunConfig` object to configure the runtime settings.
    """
    def _model_fn(features, labels, mode, config):
      return dnn_lib._dnn_model_fn(  # pylint: disable=protected-access
          features=features,
          labels=labels,
          mode=mode,
          head=head,
          hidden_units=hidden_units,
          feature_columns=tuple(feature_columns or []),
          optimizer=optimizer,
          activation_fn=activation_fn,
          dropout=dropout,
          input_layer_partitioner=input_layer_partitioner,
          config=config)
    super(DNNEstimator, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config)
tensorflow/contrib/estimator/python/estimator/dnn_test.py (new file, 153 lines)
@@ -0,0 +1,153 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dnn.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import shutil
import tempfile

import numpy as np
import six

from tensorflow.contrib.estimator.python.estimator import dnn
from tensorflow.contrib.estimator.python.estimator import head as head_lib
from tensorflow.python.estimator.canned import dnn_testing_utils
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache


def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
  """Returns a DNNEstimator that uses regression_head."""
  return dnn.DNNEstimator(
      head=head_lib.regression_head(
          weight_column=weight_column, label_dimension=label_dimension),
      *args, **kwargs)


class DNNEstimatorEvaluateTest(
    dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
        self, _dnn_estimator_fn)


class DNNEstimatorPredictTest(
    dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
        self, _dnn_estimator_fn)


class DNNEstimatorTrainTest(
    dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
        self, _dnn_estimator_fn)


class DNNEstimatorIntegrationTest(test.TestCase):

  def setUp(self):
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      writer_cache.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _test_complete_flow(
      self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
      label_dimension, batch_size):
    feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]
    est = dnn.DNNEstimator(
        head=head_lib.regression_head(label_dimension=label_dimension),
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        model_dir=self._model_dir)

    # TRAIN
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # EVALUATE
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))

    # PREDICT
    predictions = np.array([
        x[prediction_keys.PredictionKeys.PREDICTIONS]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, label_dimension), predictions.shape)

    # EXPORT
    feature_spec = feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
                                       serving_input_receiver_fn)
    self.assertTrue(gfile.Exists(export_dir))

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    label_dimension = 2
    batch_size = 10
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)
    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)


if __name__ == '__main__':
  test.main()
@@ -18,7 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary


def multi_class_head(n_classes,
@@ -33,7 +46,7 @@ def multi_class_head(n_classes,

  Args:
    n_classes: Number of classes, must be greater than 2 (for 2 classes, use
      `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
@@ -123,3 +136,206 @@ def regression_head(weight_column=None,
      weight_column=weight_column,
      label_dimension=label_dimension,
      head_name=head_name)


# TODO(roumposg): Support label_vocabulary.
def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     head_name=None):
  """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a
  multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer
  `SparseTensor` of class indices.

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    head_name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + head_name`.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes` or `thresholds` is invalid.
  """
  thresholds = tuple(thresholds) if thresholds else tuple()
  if n_classes is None or n_classes < 2:
    raise ValueError(
        'n_classes must be > 1 for multi-class classification. '
        'Given: {}'.format(n_classes))
  for threshold in thresholds:
    if (threshold <= 0.0) or (threshold >= 1.0):
      raise ValueError(
          'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
  return _MultiLabelHead(
      n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
      head_name=head_name)


class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
  """`_Head` for multi-label classification."""

  def __init__(self,
               n_classes,
               weight_column=None,
               thresholds=None,
               head_name=None):
    self._n_classes = n_classes
    self._weight_column = weight_column
    self._thresholds = thresholds
    self._head_name = head_name

  @property
  def logits_dimension(self):
    return self._n_classes

  def _process_labels(self, labels):
    if isinstance(labels, sparse_tensor.SparseTensor):
      return math_ops.to_int64(
          sparse_ops.sparse_to_indicator(labels, self._n_classes))
    msg = ('labels shape must be [batch_size, {}]. '
           'Given: ').format(self._n_classes)
    labels_shape = array_ops.shape(labels)
    check_rank_op = control_flow_ops.Assert(
        math_ops.equal(array_ops.rank(labels), 2),
        data=[msg, labels_shape])
    check_label_dim = control_flow_ops.Assert(
        math_ops.equal(labels_shape[-1], self._n_classes),
        data=[msg, labels_shape])
    with ops.control_dependencies([check_rank_op, check_label_dim]):
      return array_ops.identity(labels)

  def create_loss(self, features, mode, logits, labels):
    """See `Head`."""
    del mode, features  # Unused for this head.
    processed_labels = self._process_labels(labels)
    unweighted_loss = losses.sigmoid_cross_entropy(
        multi_class_labels=processed_labels, logits=logits,
        reduction=losses.Reduction.NONE)
    return head_lib.LossAndLabels(
        unweighted_loss=unweighted_loss,
        processed_labels=processed_labels)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """See `Head`."""
    with ops.name_scope('head'):
      logits = head_lib._check_logits(logits, self.logits_dimension)  # pylint:disable=protected-access

      # Predict.
      pred_keys = prediction_keys.PredictionKeys
      with ops.name_scope(None, 'predictions', (logits,)):
        probabilities = math_ops.sigmoid(logits, name=pred_keys.PROBABILITIES)
        predictions = {
            pred_keys.LOGITS: logits,
            pred_keys.PROBABILITIES: probabilities,
        }
      if mode == model_fn.ModeKeys.PREDICT:
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                '': export_output.ClassificationOutput(scores=probabilities)
            })

      # Eval.
      unweighted_loss, _ = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)
      # Averages loss over classes.
      per_example_loss = math_ops.reduce_mean(
          unweighted_loss, axis=-1, keep_dims=True)
      weights = head_lib._weights(features, self._weight_column)  # pylint:disable=protected-access
      training_loss = losses.compute_weighted_loss(
          per_example_loss, weights=weights, reduction=losses.Reduction.SUM)
      if mode == model_fn.ModeKeys.EVAL:
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.EVAL,
            predictions=predictions,
            loss=training_loss,
            eval_metric_ops=self._eval_metric_ops(
                labels=labels,
                probabilities=probabilities,
                weights=weights,
                per_example_loss=per_example_loss))

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
      with ops.name_scope(''):
        summary.scalar(
            head_lib._summary_key(self._head_name, metric_keys.MetricKeys.LOSS),  # pylint:disable=protected-access
            training_loss)
        summary.scalar(
            head_lib._summary_key(  # pylint:disable=protected-access
                self._head_name, metric_keys.MetricKeys.LOSS_MEAN),
            losses.compute_weighted_loss(
                unweighted_loss, weights=weights,
                reduction=losses.Reduction.MEAN))
      return model_fn.EstimatorSpec(
          mode=model_fn.ModeKeys.TRAIN,
          predictions=predictions,
          loss=training_loss,
          train_op=train_op_fn(training_loss))

  def _eval_metric_ops(self, labels, probabilities, weights, per_example_loss):
    """Returns a dict of metrics for eval_metric_ops."""
    with ops.name_scope(
        None, 'metrics', [labels, probabilities, weights, per_example_loss]):
      keys = metric_keys.MetricKeys
      metric_ops = {
          # Estimator already adds a metric for loss.
          head_lib._summary_key(self._head_name, keys.LOSS_MEAN):  # pylint:disable=protected-access
              metrics_lib.mean(
                  per_example_loss, weights=weights, name=keys.LOSS_MEAN),
          head_lib._summary_key(self._head_name, keys.AUC):  # pylint:disable=protected-access
              metrics_lib.auc(
                  labels=labels, predictions=probabilities, weights=weights,
                  name=keys.AUC),
          head_lib._summary_key(self._head_name, keys.AUC_PR):  # pylint:disable=protected-access
              metrics_lib.auc(
                  labels=labels, predictions=probabilities, weights=weights,
                  curve='PR', name=keys.AUC_PR),
      }
      for threshold in self._thresholds:
        accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._head_name, accuracy_key)] = (  # pylint:disable=protected-access
            head_lib._accuracy_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=accuracy_key))
        # Precision for positive examples.
        precision_key = keys.PRECISION_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._head_name, precision_key)] = (  # pylint:disable=protected-access
            head_lib._precision_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=precision_key))
        # Recall for positive examples.
        recall_key = keys.RECALL_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._head_name, recall_key)] = (  # pylint:disable=protected-access
            head_lib._recall_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=recall_key))
      return metric_ops
|
||||
|
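
For orientation, a minimal usage sketch of the head above wired into a custom model_fn (the feature key 'x', the optimizer choice, and my_model_fn are illustrative assumptions, not part of this change):

# Hypothetical wiring sketch; only multi_label_head / create_estimator_spec
# come from this change.
import tensorflow as tf
from tensorflow.contrib.estimator.python.estimator import head as head_lib


def my_model_fn(features, labels, mode):
  head = head_lib.multi_label_head(n_classes=4, thresholds=[0.5])
  # The head expects one logit per class.
  logits = tf.layers.dense(features['x'], head.logits_dimension)
  return head.create_estimator_spec(
      features=features,
      mode=mode,
      logits=logits,
      labels=labels,
      train_op_fn=lambda loss: tf.train.GradientDescentOptimizer(0.1).minimize(
          loss, global_step=tf.train.get_global_step()))


estimator = tf.estimator.Estimator(model_fn=my_model_fn)
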
tensorflow/contrib/estimator/python/estimator/head_test.py (new file, 570 lines)
@@ -0,0 +1,570 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for head."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensorflow.contrib.estimator.python.estimator import head as head_lib
from tensorflow.core.framework import summary_pb2
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import monitored_session


_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def _initialize_variables(test_case, scaffold):
  scaffold.finalize()
  test_case.assertIsNone(scaffold.init_feed_dict)
  test_case.assertIsNone(scaffold.init_fn)
  scaffold.init_op.run()
  scaffold.ready_for_local_init_op.eval()
  scaffold.local_init_op.run()
  scaffold.ready_op.eval()
  test_case.assertIsNotNone(scaffold.saver)


def _assert_simple_summaries(test_case, expected_summaries, summary_str,
                             tol=1e-6):
  """Asserts that the given summary contains the expected simple values.

  Args:
    test_case: A test case.
    expected_summaries: Dict of expected tags and simple values.
    summary_str: Serialized `summary_pb2.Summary`.
    tol: Relative and absolute tolerance.
  """
  summary = summary_pb2.Summary()
  summary.ParseFromString(summary_str)
  test_case.assertAllClose(expected_summaries, {
      v.tag: v.simple_value for v in summary.value
  }, rtol=tol, atol=tol)


def _assert_no_hooks(test_case, spec):
  test_case.assertAllEqual([], spec.training_chief_hooks)
  test_case.assertAllEqual([], spec.training_hooks)


def _sigmoid(logits):
  return 1 / (1 + np.exp(-logits))


def _sigmoid_cross_entropy(labels, logits):
  sigmoid_logits = _sigmoid(logits)
  return (-labels * np.log(sigmoid_logits)
          - (1 - labels) * np.log(1 - sigmoid_logits))

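# A quick numpy check (illustrative, not part of the original test file) of
# the large-logits approximation several tests below rely on. For
# |logits| >> 0, sigmoid cross entropy approaches
#   labels * (logits < 0) * (-logits) + (1 - labels) * (logits > 0) * logits.
#
#   logits = np.array([[-10., 10.], [-15., 10.]])
#   labels = np.array([[1., 0.], [1., 1.]])
#   exact = _sigmoid_cross_entropy(labels=labels, logits=logits)
#   approx = (labels * (logits < 0) * (-logits) +
#             (1 - labels) * (logits > 0) * logits)
#   np.testing.assert_allclose(exact, approx, atol=1e-4)
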
class MultiLabelHead(test.TestCase):

  def setUp(self):
    ops.reset_default_graph()

  def test_n_classes_is_none(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'n_classes must be > 1 for multi-class classification\. Given: None'):
      head_lib.multi_label_head(n_classes=None)

  def test_n_classes_is_1(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'n_classes must be > 1 for multi-class classification\. Given: 1'):
      head_lib.multi_label_head(n_classes=1)

  def test_threshold_too_small(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'thresholds must be in \(0, 1\) range\. Given: 0\.0'):
      head_lib.multi_label_head(n_classes=2, thresholds=[0., 0.5])

  def test_threshold_too_large(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'thresholds must be in \(0, 1\) range\. Given: 1\.0'):
      head_lib.multi_label_head(n_classes=2, thresholds=[0.5, 1.0])

  def test_predict(self):
    n_classes = 4
    head = head_lib.multi_label_head(n_classes)
    self.assertEqual(n_classes, head.logits_dimension)

    logits = np.array(
        [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
    expected_probabilities = _sigmoid(logits)

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.PREDICT,
        logits=logits)

    self.assertItemsEqual(
        ('', _DEFAULT_SERVING_KEY), spec.export_outputs.keys())

    # Assert predictions and export_outputs.
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      predictions = sess.run(spec.predictions)
      self.assertAllClose(logits,
                          predictions[prediction_keys.PredictionKeys.LOGITS])
      self.assertAllClose(
          expected_probabilities,
          predictions[prediction_keys.PredictionKeys.PROBABILITIES])

      self.assertAllClose(
          expected_probabilities,
          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))

  def test_weight_should_not_impact_prediction(self):
    n_classes = 4
    head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
    self.assertEqual(n_classes, head.logits_dimension)

    logits = np.array(
        [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
    expected_probabilities = _sigmoid(logits)

    weights_2x1 = [[1.], [2.]]
    spec = head.create_estimator_spec(
        features={
            'x': np.array(((42,),), dtype=np.int32),
            'label_weights': weights_2x1,
        },
        mode=model_fn.ModeKeys.PREDICT,
        logits=logits)

    # Assert predictions and export_outputs.
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      predictions = sess.run(spec.predictions)
      self.assertAllClose(logits,
                          predictions[prediction_keys.PredictionKeys.LOGITS])
      self.assertAllClose(
          expected_probabilities,
          predictions[prediction_keys.PredictionKeys.PROBABILITIES])

  def test_eval_create_loss(self):
    """Tests head.create_loss for eval mode."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    expected_unweighted_loss = _sigmoid_cross_entropy(
        labels=labels, logits=logits)
    actual_unweighted_loss, _ = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(
          expected_unweighted_loss, actual_unweighted_loss.eval())

  def test_eval_create_loss_large_logits(self):
    """Tests head.create_loss for eval mode and large logits."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # For large logits, this is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits
    expected_unweighted_loss = np.array(
        [[10., 10.], [15., 0.]], dtype=np.float32)
    actual_unweighted_loss, _ = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(
          expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)

  def test_eval_create_loss_sparse_labels(self):
    """Tests head.create_loss for eval mode and sparse labels."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = sparse_tensor.SparseTensor(
        values=[0, 0, 1],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    expected_labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # For large logits, this is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits
    expected_unweighted_loss = np.array(
        [[10., 10.], [15., 0.]], dtype=np.float32)
    actual_unweighted_loss, actual_labels = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllEqual(expected_labels, actual_labels.eval())
      self.assertAllClose(
          expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)

  def test_eval_create_loss_labels_wrong_shape(self):
    """Tests head.create_loss for eval mode when labels has the wrong shape."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
    labels_placeholder = array_ops.placeholder(dtype=dtypes.int64)
    actual_unweighted_loss, _ = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels_placeholder)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'):
        actual_unweighted_loss.eval(
            {labels_placeholder: np.array([[1], [1]], dtype=np.int64)})
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'):
        actual_unweighted_loss.eval(
            {labels_placeholder: np.array([1, 1], dtype=np.int64)})

  def test_eval(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # Average over classes, and sum over examples.
    expected_loss = (
        np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
    )

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)

    keys = metric_keys.MetricKeys
    expected_metrics = {
        # Average loss over examples.
        keys.LOSS_MEAN: expected_loss / 2,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.3333,
        keys.AUC_PR: 0.7639,
    }

    # Assert spec contains expected tensors.
    self.assertIsNotNone(spec.loss)
    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
    self.assertIsNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, and metrics.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, metrics = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of both update (in `metrics`) and value ops.
      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_eval_with_thresholds(self):
    n_classes = 2
    thresholds = [0.25, 0.5, 0.75]
    head = head_lib.multi_label_head(n_classes, thresholds=thresholds)

    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # Average over classes, and sum over examples.
    expected_loss = (
        np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
    )

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)

    keys = metric_keys.MetricKeys
    expected_metrics = {
        # Average loss over examples.
        keys.LOSS_MEAN: expected_loss / 2,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.3333,
        keys.AUC_PR: 0.7639,
        keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4.,
        keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3.,
        keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3.,
        keys.ACCURACY_AT_THRESHOLD % thresholds[1]: 1. / 4.,
        keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1. / 2.,
        keys.RECALL_AT_THRESHOLD % thresholds[1]: 1. / 3.,
        keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 2. / 4.,
        keys.PRECISION_AT_THRESHOLD % thresholds[2]: 1. / 1.,
        keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3.,
    }

    # Assert spec contains expected tensors.
    self.assertIsNotNone(spec.loss)
    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
    self.assertIsNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, and metrics.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, metrics = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of both update (in `metrics`) and value ops.
      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_eval_with_weights(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes, weight_column='label_weights')

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, weighted sum over examples.
    expected_loss = 25.

    spec = head.create_estimator_spec(
        features={
            'x': np.array([[41], [42]], dtype=np.int32),
            'label_weights': np.array([[1.], [2.]], dtype=np.float32),
        },
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)

    keys = metric_keys.MetricKeys
    expected_metrics = {
        # Average loss over weighted examples.
        keys.LOSS_MEAN: expected_loss / 3,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.2000,
        keys.AUC_PR: 0.7833,
    }

    # Assert spec contains expected tensors.
    self.assertIsNotNone(spec.loss)
    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
    self.assertIsNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, and metrics.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, metrics = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of both update (in `metrics`) and value ops.
      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_train_create_loss_large_logits(self):
    """Tests head.create_loss for train mode and large logits."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # For large logits, this is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits
    expected_unweighted_loss = np.array(
        [[10., 10.], [15., 0.]], dtype=np.float32)
    actual_unweighted_loss, _ = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(
          expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)

  def test_train(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, sum over examples.
    expected_loss = 17.5
    expected_train_result = 'my_train_op'

    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)

    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      _assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
          # Average loss over examples.
          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
      }, summary_str, tol)

  def test_train_with_weights(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes, weight_column='label_weights')

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, weighted sum over examples.
    expected_loss = 25.
    expected_train_result = 'my_train_op'

    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={
            'x': np.array([[41], [42]], dtype=np.int32),
            'label_weights': np.array([[1.], [2.]], dtype=np.float32),
        },
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)

    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      _assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
          # Average loss over weighted examples.
          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
      }, summary_str, tol)


if __name__ == '__main__':
  test.main()
@@ -221,6 +221,7 @@ tf_py_test(
        "manual",
        "noasan",  # times out b/63678675
        "nomsan",
        "notsan",
    ],
)

@@ -14,10 +14,12 @@ py_library(
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":eval",
        ":features",
        ":losses",
        ":namedtuples",
        ":train",
        "//tensorflow/python:util",
    ],
)

@@ -73,6 +75,18 @@ py_test(
    ],
)

py_library(
    name = "eval",
    srcs = ["python/eval/__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":classifier_metrics",
        ":eval_utils",
        ":summaries",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "losses",
    srcs = ["python/losses/__init__.py"],
@@ -257,6 +271,105 @@ py_test(
    ],
)

py_library(
    name = "classifier_metrics",
    srcs = [
        "python/eval/python/classifier_metrics.py",
        "python/eval/python/classifier_metrics_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python:image_ops",
        "//tensorflow/python:linalg_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:util",
    ],
)

py_test(
    name = "classifier_metrics_test",
    srcs = ["python/eval/python/classifier_metrics_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":classifier_metrics",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "eval_utils",
    srcs = [
        "python/eval/python/eval_utils.py",
        "python/eval/python/eval_utils_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
    ],
)

py_test(
    name = "eval_utils_test",
    srcs = ["python/eval/python/eval_utils_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":eval_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
    ],
)

py_library(
    name = "summaries",
    srcs = [
        "python/eval/python/summaries.py",
        "python/eval/python/summaries_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":eval_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:util",
        "//tensorflow/python/ops/losses",
    ],
)

py_test(
    name = "summaries_test",
    srcs = ["python/eval/python/summaries_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":namedtuples",
        ":summaries",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

# Collapse TFGAN into a tiered namespace.
from tensorflow.contrib.gan.python import eval  # pylint:disable=redefined-builtin
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses
from tensorflow.contrib.gan.python import namedtuples
@@ -32,6 +33,7 @@ from tensorflow.contrib.gan.python.train import *
from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    'eval',
    'features',
    'losses',
]
tensorflow/contrib/gan/python/eval/__init__.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API. Please see README.md for details and usage."""
# pylint: disable=wildcard-import,unused-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Collapse eval into a single namespace.
from tensorflow.contrib.gan.python.eval.python import classifier_metrics
from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.contrib.gan.python.eval.python import summaries

from tensorflow.contrib.gan.python.eval.python.classifier_metrics import *
from tensorflow.contrib.gan.python.eval.python.eval_utils import *
from tensorflow.contrib.gan.python.eval.python.summaries import *
# pylint: enable=wildcard-import,unused-import

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    'classifier_metrics',
    'summaries',
    'eval_utils',
] + classifier_metrics.__all__ + summaries.__all__ + eval_utils.__all__
remove_undocumented(__name__, _allowed_symbols)
tensorflow/contrib/gan/python/eval/python/classifier_metrics.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model evaluation tools for TFGAN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.eval.python import classifier_metrics_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.eval.python.classifier_metrics_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

__all__ = classifier_metrics_impl.__all__
remove_undocumented(__name__, __all__)
tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py (new file, 401 lines)
@@ -0,0 +1,401 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model evaluation tools for TFGAN.

These methods come from https://arxiv.org/abs/1606.03498 and
https://arxiv.org/abs/1706.08500.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import sys
import tarfile

from six.moves import urllib

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import gfile


__all__ = [
    'get_graph_def_from_disk',
    'preprocess_image',
    'run_image_classifier',
    'run_inception',
    'inception_score',
    'classifier_score',
    'frechet_inception_distance',
    'frechet_classifier_distance',
]


INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v3_2017_09_13.tar.gz'
INCEPTION_FROZEN_GRAPH = 'frozen_inception_v3.pb'
INCEPTION_V3_INPUT = 'inputs'
INCEPTION_V3_OUTPUT = 'InceptionV3/Logits/SpatialSqueeze:0'
INCEPTION_V3_FINAL_POOL = 'InceptionV3/Logits/AvgPool_1a_8x8/AvgPool:0'
_INCEPTION_V3_NUM_CLASSES = 1001
_INCEPTION_V3_FINAL_POOL_SIZE = 2048
INCEPTION_V3_DEFAULT_IMG_SIZE = 299


def _validate_images(images, image_size):
  images = ops.convert_to_tensor(images)
  images.shape.with_rank(4)
  images.shape.assert_is_compatible_with(
      [None, image_size, image_size, None])
  return images


def _matrix_square_root(mat, eps=1e-10):
  """Computes the symmetric square root of a matrix.

  Equivalent to the matrix square root when the matrix is invertible; note
  that this is different from an elementwise square root. We want to compute
  M' where M' = sqrt(mat) such that M' * M' = mat.

  Args:
    mat: Matrix to take the square root of.
    eps: Small epsilon such that any element less than eps will not be square
      rooted to guard against numerical instability.

  Returns:
    Matrix square root of mat.
  """
  s, u, v = linalg_ops.svd(mat)
  # sqrt is unstable around 0; just use 0 in such a case.
  si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
  return math_ops.matmul(
      math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)

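# Illustrative numpy analogue (not part of this file) of the SVD construction
# above, checking that M' * M' ~= mat for a symmetric positive semi-definite
# matrix:
#
#   a = np.random.randn(4, 4)
#   mat = a.dot(a.T)
#   u, s, vh = np.linalg.svd(mat)      # numpy returns V already transposed
#   root = u.dot(np.diag(np.sqrt(s))).dot(vh)
#   np.testing.assert_allclose(root.dot(root), mat, atol=1e-8)
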
# Convenience preprocessing function, with fixed defaults.
# NOTE: Floating-point inputs are expected to be in [0, 1].
# Copied from /tensorflow_models/slim/preprocessing/inception_preprocessing.py.
def preprocess_image(
    image, height=INCEPTION_V3_DEFAULT_IMG_SIZE,
    width=INCEPTION_V3_DEFAULT_IMG_SIZE, central_fraction=0.875, scope=None):
  """Prepares one image for evaluation.

  If height and width are specified, it outputs an image of that size by
  applying resize_bilinear.

  If central_fraction is specified, it crops the central fraction of the
  input image.

  Args:
    image: 3-D Tensor of an image. If dtype is tf.float32 then the range
      should be [0, 1]; otherwise it is converted to tf.float32 assuming that
      the range is [0, MAX], where MAX is the largest positive representable
      number for int(8/16/32) data types (see `tf.image.convert_image_dtype`
      for details).
    height: Integer height of the output image.
    width: Integer width of the output image.
    central_fraction: Optional float, fraction of the image to crop.
    scope: Optional scope for name_scope.

  Returns:
    3-D float Tensor of the prepared image.
  """
  with ops.name_scope(scope, 'eval_image', [image, height, width]):
    if image.dtype != dtypes.float32:
      image = image_ops.convert_image_dtype(image, dtype=dtypes.float32)
    # Crop the central region of the image with an area containing 87.5% of
    # the original image.
    image = image_ops.central_crop(image, central_fraction=central_fraction)

    # Resize the image to the specified height and width.
    image = array_ops.expand_dims(image, 0)
    image = image_ops.resize_bilinear(image, [height, width],
                                      align_corners=False)
    image = array_ops.squeeze(image, [0])
    image = (image - 0.5) * 2.0
    return image


def _kl_divergence(p, p_logits, q):
  """Computes the Kullback-Leibler divergence between p and q.

  This function uses p's logits in some places to improve numerical stability.

  Specifically:

  KL(p || q) = sum[ p * log(p / q) ]
             = sum[ p * ( log(p)                - log(q) ) ]
             = sum[ p * ( log_softmax(p_logits) - log(q) ) ]

  Args:
    p: A 2-D floating-point Tensor p_ij, where `i` corresponds to the
      minibatch example and `j` corresponds to the probability of being in
      class `j`.
    p_logits: A 2-D floating-point Tensor corresponding to logits for `p`.
    q: A 1-D floating-point Tensor, where q_j corresponds to the probability
      of class `j`.

  Returns:
    KL divergence between the two distributions. The output is 1-D, with one
    entry per distribution in `p`.

  Raises:
    ValueError: If any of the inputs aren't floating-point.
    ValueError: If p or p_logits aren't 2-D.
    ValueError: If q isn't 1-D.
  """
  for tensor in [p, p_logits, q]:
    if not tensor.dtype.is_floating:
      raise ValueError('Input %s must be floating type.' % tensor.name)
  p.shape.assert_has_rank(2)
  p_logits.shape.assert_has_rank(2)
  q.shape.assert_has_rank(1)
  return math_ops.reduce_sum(
      p * (nn_ops.log_softmax(p_logits) - math_ops.log(q)), axis=1)

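# Illustrative numpy check (not part of this file) of the identity in the
# docstring above: with p = softmax(logits),
# sum[p * (log_softmax(logits) - log(q))] equals sum[p * log(p / q)].
#
#   logits = np.array([[0., 1., 2.]])
#   z = np.log(np.sum(np.exp(logits), axis=1, keepdims=True))
#   p = np.exp(logits - z)
#   q = np.array([0.2, 0.3, 0.5])
#   kl_a = np.sum(p * ((logits - z) - np.log(q)), axis=1)
#   kl_b = np.sum(p * (np.log(p) - np.log(q)), axis=1)
#   np.testing.assert_allclose(kl_a, kl_b)
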
def get_graph_def_from_disk(filename):
  """Get a GraphDef proto from a disk location."""
  with gfile.FastGFile(filename, 'rb') as f:
    return graph_pb2.GraphDef.FromString(f.read())


def get_graph_def_from_url_tarball(url, filename):
  """Get a GraphDef proto from a tarball on the web."""
  def _progress(count, block_size, total_size):
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        url, float(count * block_size) / float(total_size) * 100.0))
    sys.stdout.flush()
  tar_filename, _ = urllib.request.urlretrieve(url, reporthook=_progress)
  with tarfile.open(tar_filename, 'r:gz') as tar:
    proto_str = tar.extractfile(filename).read()
  return graph_pb2.GraphDef.FromString(proto_str)


def _default_graph_def_fn():
  return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH)


def run_inception(images,
                  graph_def=None,
                  default_graph_def_fn=_default_graph_def_fn,
                  image_size=INCEPTION_V3_DEFAULT_IMG_SIZE,
                  input_tensor=INCEPTION_V3_INPUT,
                  output_tensor=INCEPTION_V3_OUTPUT):
  """Runs images through a pretrained Inception classifier.

  Args:
    images: Input tensors. Must be [batch, height, width, channels]. Input
      shape and values must be in [-1, 1], which can be achieved using
      `preprocess_image`.
    graph_def: A GraphDef proto of a pretrained Inception graph. If `None`,
      calls `default_graph_def_fn` to get the GraphDef.
    default_graph_def_fn: A function that returns a GraphDef. Used if
      `graph_def` is `None`. By default, returns a pretrained InceptionV3
      graph.
    image_size: Required image width and height. See unit tests for the
      default values.
    input_tensor: Name of the input Tensor.
    output_tensor: Name of the output Tensor. This function will compute
      activations at the specified layer. Examples include
      INCEPTION_V3_OUTPUT and INCEPTION_V3_FINAL_POOL, which would result in
      this function computing the final logits or the penultimate pooling
      layer, respectively.

  Returns:
    Logits (or other activations, depending on `output_tensor`).

  Raises:
    ValueError: If images are not the correct size.
    ValueError: If neither `graph_def` nor `default_graph_def_fn` is
      provided.
  """
  images = _validate_images(images, image_size)

  if graph_def is None:
    if default_graph_def_fn is None:
      raise ValueError('If `graph_def` is `None`, must provide '
                       '`default_graph_def_fn`.')
    graph_def = default_graph_def_fn()

  activations = run_image_classifier(images, graph_def, input_tensor,
                                     output_tensor)
  if array_ops.rank(activations) != 2:
    activations = layers.flatten(activations)
  return activations


def run_image_classifier(tensor, graph_def, input_tensor,
                         output_tensor, scope='RunClassifier'):
  """Runs a network from a frozen graph.

  Args:
    tensor: An input tensor.
    graph_def: A GraphDef proto.
    input_tensor: Name of the input tensor in the graph def.
    output_tensor: Name of the output tensor in the graph def.
    scope: Name scope for the classifier.

  Returns:
    Classifier output. Shape depends on the classifier used, but is often
    [batch, classes].

  Raises:
    ValueError: If `image_size` is not `None` and `tensor` is not the correct
      size.
  """
  input_map = {input_tensor: tensor}
  return_elements = [output_tensor]
  classifier_output = importer.import_graph_def(
      graph_def, input_map, return_elements, name=scope)[0]

  return classifier_output


def classifier_score(images, classifier_fn, num_batches=1):
  """Classifier score for evaluating a conditional generative model.

  This is based on the Inception Score, but for an arbitrary classifier.

  This technique is described in detail in https://arxiv.org/abs/1606.03498.
  In summary, this function calculates

  exp( E[ KL(p(y|x) || p(y)) ] )

  which captures how different the network's classification prediction is
  from the prior distribution over classes.

  Args:
    images: Images to calculate the classifier score for.
    classifier_fn: A function that takes images and produces logits based on
      a classifier.
    num_batches: Number of batches to split `generated_images` into in order
      to efficiently run them through the classifier network.

  Returns:
    The classifier score. A floating-point scalar.
  """
  generated_images_list = array_ops.split(
      images, num_or_size_splits=num_batches)

  # Compute the classifier splits using the memory-efficient `map_fn`.
  logits = functional_ops.map_fn(
      fn=classifier_fn,
      elems=array_ops.stack(generated_images_list),
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')
  logits = array_ops.concat(array_ops.unstack(logits), 0)
  logits.shape.assert_has_rank(2)
  p = nn_ops.softmax(logits)
  q = math_ops.reduce_mean(p, axis=0)
  kl = _kl_divergence(p, logits, q)
  kl.shape.assert_has_rank(1)
  log_score = math_ops.reduce_mean(kl)

  return math_ops.exp(log_score)


inception_score = functools.partial(
    classifier_score,
    classifier_fn=functools.partial(
        run_inception, output_tensor=INCEPTION_V3_OUTPUT))

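# Hedged usage sketch (not part of this file): computing an Inception score
# for a single image. `img` is hypothetical; without a `graph_def` argument,
# run_inception downloads the frozen InceptionV3 tarball from INCEPTION_URL.
#
#   img = ...  # [H, W, 3] float image with values in [0, 1]
#   batch = array_ops.expand_dims(preprocess_image(img), 0)  # [-1, 1] range
#   score = inception_score(batch)
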
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Classifier distance for evaluating a conditional generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance
  matrices C and C_w, this function calculates

  |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and
  generated images (or, more accurately, their visual features) are. Note
  that unlike the Inception score, this is a true distance and utilizes
  information about real-world images.

  Args:
    real_images: Real images to use to compute the Frechet Inception
      distance.
    generated_images: Generated images to use to compute the Frechet
      Inception distance.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images into in order to
      efficiently run them through the classifier network.

  Returns:
    The Frechet Inception distance. A floating-point scalar.
  """

  real_images_list = array_ops.split(
      real_images, num_or_size_splits=num_batches)
  generated_images_list = array_ops.split(
      generated_images, num_or_size_splits=num_batches)

  imgs = array_ops.stack(real_images_list + generated_images_list)

  # Compute the activations using the memory-efficient `map_fn`.
  activations = functional_ops.map_fn(
      fn=classifier_fn,
      elems=imgs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # Split the activations by the real and generated images.
  real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

  # Ensure the activations have the right shapes.
  real_a = array_ops.concat(array_ops.unstack(real_a), 0)
  gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
  real_a.shape.assert_has_rank(2)
  gen_a.shape.assert_has_rank(2)

  # Compute the mean and covariance matrices of the activations.
  m = math_ops.reduce_mean(real_a, 0)
  m_v = math_ops.reduce_mean(gen_a, 0)
  dim = math_ops.to_float(array_ops.shape(m)[0])
  sigma = math_ops.matmul(real_a - m, real_a - m, transpose_b=True) / dim
  sigma_v = math_ops.matmul(gen_a - m, gen_a - m, transpose_b=True) / dim

  # Take the matrix square root of the product of the covariance matrices.
  sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))

  # Compute the two components of FID.
  trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc)
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean

  return fid


frechet_inception_distance = functools.partial(
    frechet_classifier_distance,
    classifier_fn=functools.partial(
        run_inception, output_tensor=INCEPTION_V3_FINAL_POOL))
tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py (new file, 316 lines)
@@ -0,0 +1,316 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFGAN classifier_metrics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile
import tempfile

import numpy as np

from google.protobuf import text_format

from tensorflow.contrib.gan.python.eval.python import classifier_metrics_impl as classifier_metrics
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

mock = test.mock


def _numpy_softmax(x):
  e_x = np.exp(x - np.max(x, axis=1)[:, None])
  return e_x / np.sum(e_x, axis=1)[:, None]


def _expected_inception_score(logits):
  p = _numpy_softmax(logits)
  q = np.expand_dims(np.mean(p, 0), 0)
  per_example_logincscore = np.sum(p * (np.log(p) - np.log(q)), 1)
  return np.exp(np.mean(per_example_logincscore))


def _approximate_matrix_sqrt(mat, eps=1e-8):
  # Note: np.linalg.svd returns (u, s, v_transposed), so no extra transpose
  # is needed before the final product.
  u, s, v = np.linalg.svd(mat)
  si = np.where(s < eps, s, np.sqrt(s))
  return np.dot(np.dot(u, np.diag(si)), v)


def _expected_fid(real_imgs, gen_imgs):
  real_imgs = np.asarray(real_imgs)
  gen_imgs = np.asarray(gen_imgs)
  m = np.mean(real_imgs, axis=0)
  m_v = np.mean(gen_imgs, axis=0)
  dim = float(m.shape[0])
  sigma = np.dot((real_imgs - m), (real_imgs - m).T) / dim
  sigma_v = np.dot((gen_imgs - m), (gen_imgs - m).T) / dim
  sqcc = _approximate_matrix_sqrt(np.dot(sigma, sigma_v))
  mean = np.square(np.linalg.norm(m - m_v))
  trace = np.trace(sigma + sigma_v - 2 * sqcc)
  fid = mean + trace
  return fid


# A dummy GraphDef string with the minimum number of Ops.
graphdef_string = """
node {
  name: "inputs"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 299
        }
        dim {
          size: 299
        }
        dim {
          size: 3
        }
      }
    }
  }
}
node {
  name: "InceptionV3/Logits/SpatialSqueeze"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 1001
        }
      }
    }
  }
}
node {
  name: "InceptionV3/Logits/AvgPool_1a_8x8/AvgPool"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 2048
        }
      }
    }
  }
}
versions {
  producer: 24
}
"""


def _get_dummy_graphdef():
  dummy_graphdef = graph_pb2.GraphDef()
  text_format.Merge(graphdef_string, dummy_graphdef)
  return dummy_graphdef


def _run_with_mock(function, *args, **kwargs):
  with mock.patch.object(
      classifier_metrics,
      'get_graph_def_from_url_tarball') as mock_tarball_getter:
    mock_tarball_getter.return_value = _get_dummy_graphdef()
    return function(*args, **kwargs)

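# Because the dummy GraphDef above consists only of placeholders, the tests
# below can feed classifier outputs directly by tensor name (for example
# 'concat:0' or 'RunClassifier/TensorArrayStack/TensorArrayGatherV3:0')
# instead of running a real InceptionV3 network.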
class ClassifierMetricsTest(test.TestCase):
|
||||
|
||||
def test_run_inception_graph(self):
|
||||
"""Test `run_inception` graph construction."""
|
||||
batch_size = 7
|
||||
img = array_ops.ones([batch_size, 299, 299, 3])
|
||||
logits = _run_with_mock(classifier_metrics.run_inception, img)
|
||||
|
||||
self.assertTrue(isinstance(logits, ops.Tensor))
|
||||
logits.shape.assert_is_compatible_with([batch_size, 1001])
|
||||
|
||||
# Check that none of the model variables are trainable.
|
||||
self.assertListEqual([], variables.trainable_variables())
|
||||
|
||||
def test_run_inception_graph_pool_output(self):
|
||||
"""Test `run_inception` graph construction with pool output."""
|
||||
batch_size = 3
|
||||
img = array_ops.ones([batch_size, 299, 299, 3])
|
||||
    pool = _run_with_mock(
        classifier_metrics.run_inception, img,
        output_tensor=classifier_metrics.INCEPTION_V3_FINAL_POOL)

    self.assertTrue(isinstance(pool, ops.Tensor))
    pool.shape.assert_is_compatible_with([batch_size, 2048])

    # Check that none of the model variables are trainable.
    self.assertListEqual([], variables.trainable_variables())

  def test_inception_score_graph(self):
    """Test `inception_score` graph construction."""
    score = _run_with_mock(classifier_metrics.inception_score,
                           array_ops.zeros([6, 299, 299, 3]), num_batches=3)
    self.assertTrue(isinstance(score, ops.Tensor))
    score.shape.assert_has_rank(0)

    # Check that none of the model variables are trainable.
    self.assertListEqual([], variables.trainable_variables())

  def test_frechet_inception_distance_graph(self):
    """Test `frechet_inception_distance` graph construction."""
    img = array_ops.ones([7, 299, 299, 3])
    distance = _run_with_mock(
        classifier_metrics.frechet_inception_distance, img, img)

    self.assertTrue(isinstance(distance, ops.Tensor))
    distance.shape.assert_has_rank(0)

    # Check that none of the model variables are trainable.
    self.assertListEqual([], variables.trainable_variables())

  def test_run_inception_multicall(self):
    """Test that `run_inception` can be called multiple times."""
    for batch_size in (7, 3, 2):
      img = array_ops.ones([batch_size, 299, 299, 3])
      _run_with_mock(classifier_metrics.run_inception, img)

  def test_invalid_input(self):
    """Test that functions properly fail on invalid input."""
    with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
      classifier_metrics.run_inception(array_ops.ones([7, 50, 50, 3]))

    p = array_ops.zeros([8, 10])
    p_logits = array_ops.zeros([8, 10])
    q = array_ops.zeros([10])
    with self.assertRaisesRegexp(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)

    with self.assertRaisesRegexp(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)

    with self.assertRaisesRegexp(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))

    with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
      classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q)

    with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
      classifier_metrics._kl_divergence(p, array_ops.zeros([8]), q)

    with self.assertRaisesRegexp(ValueError, 'must have rank 1'):
      classifier_metrics._kl_divergence(p, p_logits, array_ops.zeros([10, 8]))

  def test_inception_score_value(self):
    """Test that `inception_score` gives the correct value."""
    logits = np.array([np.array([1, 2] * 500 + [4]),
                       np.array([4, 5] * 500 + [6])])
    unused_image = array_ops.zeros([2, 299, 299, 3])
    incscore = _run_with_mock(classifier_metrics.inception_score, unused_image)

    with self.test_session(use_gpu=True) as sess:
      incscore_np = sess.run(incscore, {'concat:0': logits})

    self.assertAllClose(_expected_inception_score(logits), incscore_np)

  def test_frechet_inception_distance_value(self):
    """Test that `frechet_inception_distance` gives the correct value."""
    np.random.seed(0)
    test_pool_real_a = np.random.randn(5, 2048)
    test_pool_gen_a = np.random.randn(5, 2048)
    unused_image = array_ops.zeros([5, 299, 299, 3])

    pool_a = np.stack((test_pool_real_a, test_pool_gen_a))
    fid_op = _run_with_mock(classifier_metrics.frechet_inception_distance,
                            unused_image, unused_image)
    activations_tensor = 'RunClassifier/TensorArrayStack/TensorArrayGatherV3:0'

    with self.test_session() as sess:
      actual_fid = sess.run(fid_op, {activations_tensor: pool_a})

    expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a)
    self.assertAllClose(expected_fid, actual_fid, 0.01)

  def test_preprocess_image_graph(self):
    """Test `preprocess_image` graph construction."""
    incorrectly_sized_image = array_ops.zeros([520, 240, 3])
    correct_image = classifier_metrics.preprocess_image(
        image=incorrectly_sized_image)
    _run_with_mock(classifier_metrics.run_inception,
                   array_ops.expand_dims(correct_image, 0))

  def test_get_graph_def_from_url_tarball(self):
    """Test `get_graph_def_from_url_tarball`."""
    # Write dummy binary GraphDef to tempfile.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
      tmp_file.write(_get_dummy_graphdef().SerializeToString())
    relative_path = os.path.relpath(tmp_file.name)

    # Create gzip tarball.
    tar_dir = tempfile.mkdtemp()
    tar_filename = os.path.join(tar_dir, 'tmp.tar.gz')
    with tarfile.open(tar_filename, 'w:gz') as tar:
      tar.add(relative_path)

    with mock.patch.object(classifier_metrics, 'urllib') as mock_urllib:
      mock_urllib.request.urlretrieve.return_value = tar_filename, None
      graph_def = classifier_metrics.get_graph_def_from_url_tarball(
          'unused_url', relative_path)

    self.assertIsInstance(graph_def, graph_pb2.GraphDef)
    self.assertEqual(_get_dummy_graphdef(), graph_def)


if __name__ == '__main__':
  test.main()
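For context, `_expected_inception_score` and `_expected_fid` (defined earlier in this file) compute the reference values these tests check against. A numpy sketch of the standard definitions they are meant to mirror, as an illustration rather than the file's exact code, assuming scipy is available for the matrix square root:

import numpy as np
from scipy import linalg


def expected_inception_score(logits):
  """Classic Inception Score: exp(E_x[KL(p(y|x) || p(y))])."""
  p = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
  q = np.mean(p, axis=0)  # Marginal class distribution over the batch.
  kl = np.sum(p * (np.log(p) - np.log(q)), axis=1)
  return np.exp(np.mean(kl))


def expected_fid(real_pool, gen_pool):
  """Frechet distance between Gaussians fit to the two activation pools."""
  m, m_w = np.mean(real_pool, axis=0), np.mean(gen_pool, axis=0)
  c, c_w = np.cov(real_pool, rowvar=False), np.cov(gen_pool, rowvar=False)
  sqrt_term = linalg.sqrtm(np.dot(c, c_w)).real
  return np.sum((m - m_w) ** 2) + np.trace(c + c_w - 2 * sqrt_term)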
tensorflow/contrib/gan/python/eval/python/eval_utils.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility file for visualizing generated images."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.eval.python import eval_utils_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.eval.python.eval_utils_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

__all__ = eval_utils_impl.__all__
remove_undocumented(__name__, __all__)
tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility file for visualizing generated images."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


__all__ = [
    "image_grid",
    "image_reshaper",
]


# TODO(joelshor): Make this a special case of `image_reshaper`.
def image_grid(input_tensor, grid_shape, image_shape=(32, 32), num_channels=3):
  """Arrange a minibatch of images into a grid to form a single image.

  Args:
    input_tensor: Tensor. Minibatch of images to format, either 4D
        ([batch size, height, width, num_channels]) or flattened
        ([batch size, height * width * num_channels]).
    grid_shape: Sequence of int. The shape of the image grid,
        formatted as [grid_height, grid_width].
    image_shape: Sequence of int. The shape of a single image,
        formatted as [image_height, image_width].
    num_channels: int. The number of channels in an image.

  Returns:
    Tensor representing a single image in which the input images have been
    arranged into a grid.

  Raises:
    ValueError: The grid shape and minibatch size don't match, or the image
        shape and number of channels are incompatible with the input tensor.
  """
  if grid_shape[0] * grid_shape[1] != int(input_tensor.shape[0]):
    raise ValueError("Grid shape %s incompatible with minibatch size %i." %
                     (grid_shape, int(input_tensor.shape[0])))
  if len(input_tensor.shape) == 2:
    num_features = image_shape[0] * image_shape[1] * num_channels
    if int(input_tensor.shape[1]) != num_features:
      raise ValueError("Image shape and number of channels incompatible with "
                       "input tensor.")
  elif len(input_tensor.shape) == 4:
    if (int(input_tensor.shape[1]) != image_shape[0] or
        int(input_tensor.shape[2]) != image_shape[1] or
        int(input_tensor.shape[3]) != num_channels):
      raise ValueError("Image shape and number of channels incompatible with "
                       "input tensor.")
  else:
    raise ValueError("Unrecognized input tensor format.")
  height, width = grid_shape[0] * image_shape[0], grid_shape[1] * image_shape[1]
  input_tensor = array_ops.reshape(
      input_tensor, tuple(grid_shape) + tuple(image_shape) + (num_channels,))
  input_tensor = array_ops.transpose(input_tensor, [0, 1, 3, 2, 4])
  input_tensor = array_ops.reshape(
      input_tensor, [grid_shape[0], width, image_shape[0], num_channels])
  input_tensor = array_ops.transpose(input_tensor, [0, 2, 1, 3])
  input_tensor = array_ops.reshape(
      input_tensor, [1, height, width, num_channels])
  return input_tensor


def _validate_images(images):
  for img in images:
    img.shape.assert_has_rank(3)
    img.shape.assert_is_fully_defined()
    if img.shape[-1] not in (1, 3):
      raise ValueError("image_reshaper only supports 1 or 3 channel images.")


# TODO(joelshor): Move the dimension logic from Python to TensorFlow.
def image_reshaper(images, num_cols=None):
  """A reshaped summary image.

  Returns an image that will contain all elements in the list and will be
  laid out in a nearly-square tiling pattern (e.g. 11 images will lead to a
  3x4 tiled image).

  Args:
    images: Image data to summarize. Can be an RGB or grayscale image, a list
        of such images, or a set of RGB images concatenated along the depth
        dimension. A batched `Tensor` is assumed to have shape
        [batch_size, height, width, depth]; each image in a list is assumed
        to have shape [height, width, depth].
    num_cols: (Optional) If provided, this is the number of columns in the
        final output image grid. Otherwise, the number of columns is
        determined by the number of images.

  Returns:
    A summary image matching the input with automatic tiling if needed.
    Output shape is [1, height, width, channels].
  """
  if isinstance(images, ops.Tensor):
    images = array_ops.unstack(images)
  _validate_images(images)

  num_images = len(images)
  num_columns = (num_cols if num_cols else
                 int(math.ceil(math.sqrt(num_images))))
  num_rows = int(math.ceil(float(num_images) / num_columns))
  rows = [images[x:x+num_columns] for x in range(0, num_images, num_columns)]

  # Add empty image tiles if the last row is incomplete.
  num_short = num_rows * num_columns - num_images
  assert num_short >= 0 and num_short < num_columns
  if num_short > 0:
    rows[-1].extend([array_ops.zeros_like(images[-1])] * num_short)

  # Convert each row from a list of tensors to a single tensor.
  rows = [array_ops.concat(row, 1) for row in rows]

  # Stack rows vertically.
  img = array_ops.concat(rows, 0)

  return array_ops.expand_dims(img, 0)
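As a quick reference for the two helpers above, a usage sketch (shapes are illustrative; it mirrors the tests in eval_utils_test.py below):

import tensorflow as tf
from tensorflow.contrib.gan.python.eval.python import eval_utils_impl as eval_utils

batch = tf.zeros([25, 32, 32, 3])  # 25 RGB images.
grid = eval_utils.image_grid(batch, grid_shape=(5, 5),
                             image_shape=(32, 32), num_channels=3)
print(grid.shape)   # (1, 160, 160, 3): a 5x5 tiling of 32x32 images.
tiled = eval_utils.image_reshaper(batch, num_cols=4)
print(tiled.shape)  # (1, 224, 128, 3): 7 rows x 4 cols, last row zero-padded.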
tensorflow/contrib/gan/python/eval/python/eval_utils_test.py (new file, 48 lines)
@@ -0,0 +1,48 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for eval_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.eval.python import eval_utils_impl as eval_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class UtilsTest(test.TestCase):

  def test_image_grid(self):
    eval_utils.image_grid(
        input_tensor=array_ops.zeros([25, 32, 32, 3]),
        grid_shape=(5, 5))

  # TODO(joelshor): Add more `image_reshaper` tests.
  def test_image_reshaper_image_list(self):
    images = eval_utils.image_reshaper(
        images=array_ops.unstack(array_ops.zeros([25, 32, 32, 3])),
        num_cols=2)
    images.shape.assert_is_compatible_with([1, 13 * 32, 2 * 32, 3])

  def test_image_reshaper_image(self):
    images = eval_utils.image_reshaper(
        images=array_ops.zeros([25, 32, 32, 3]),
        num_cols=2)
    images.shape.assert_is_compatible_with([1, 13 * 32, 2 * 32, 3])


if __name__ == '__main__':
  test.main()
tensorflow/contrib/gan/python/eval/python/summaries.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common TFGAN summaries."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.eval.python import summaries_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.eval.python.summaries_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

__all__ = summaries_impl.__all__
remove_undocumented(__name__, __all__)
tensorflow/contrib/gan/python/eval/python/summaries_impl.py (new file, 157 lines)
@@ -0,0 +1,157 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common TFGAN summaries."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import util as loss_util
from tensorflow.python.summary import summary

__all__ = [
    'add_gan_model_image_summaries',
    'add_image_comparison_summaries',
    'add_gan_model_summaries',
    'add_regularization_loss_summaries',
]


def _assert_is_image(data):
  data.shape.assert_has_rank(4)
  data.shape[1:].assert_is_fully_defined()


def add_gan_model_image_summaries(gan_model, grid_size=10):
  """Adds image summaries for real and fake images.

  Args:
    gan_model: A GANModel tuple.
    grid_size: The size of an image grid.

  Raises:
    ValueError: If real and generated data aren't images.
  """
  _assert_is_image(gan_model.real_data)
  _assert_is_image(gan_model.generated_data)

  num_images = grid_size ** 2
  real_image_shape = gan_model.real_data.shape.as_list()[1:3]
  generated_image_shape = gan_model.generated_data.shape.as_list()[1:3]
  real_channels = gan_model.real_data.shape.as_list()[3]
  generated_channels = gan_model.generated_data.shape.as_list()[3]

  summary.image(
      'real_data',
      eval_utils.image_grid(
          gan_model.real_data[:num_images],
          grid_shape=(grid_size, grid_size),
          image_shape=real_image_shape,
          num_channels=real_channels),
      max_outputs=1)
  summary.image(
      'generated_data',
      eval_utils.image_grid(
          gan_model.generated_data[:num_images],
          grid_shape=(grid_size, grid_size),
          image_shape=generated_image_shape,
          num_channels=generated_channels),
      max_outputs=1)
  add_gan_model_summaries(gan_model)


def add_image_comparison_summaries(gan_model, num_comparisons=2,
                                   display_diffs=False):
  """Adds image summaries to compare triplets of images.

  The first image is the generator input, the second is the generator output,
  and the third is the real data. This style of comparison is useful for
  image translation problems, where the generator input is a corrupted image,
  the generator output is the reconstruction, and the real data is the target.

  Args:
    gan_model: A GANModel tuple.
    num_comparisons: The number of image triplets to display.
    display_diffs: Also display the difference between generated and target.

  Raises:
    ValueError: If real data, generated data, and generator inputs aren't
      images.
    ValueError: If the generator input, real, and generated data aren't all the
      same size.
  """
  _assert_is_image(gan_model.generator_inputs)
  _assert_is_image(gan_model.generated_data)
  _assert_is_image(gan_model.real_data)

  gan_model.generated_data.shape.assert_is_compatible_with(
      gan_model.generator_inputs.shape)
  gan_model.real_data.shape.assert_is_compatible_with(
      gan_model.generated_data.shape)

  image_list = []
  image_list.extend(
      array_ops.unstack(gan_model.generator_inputs[:num_comparisons]))
  image_list.extend(
      array_ops.unstack(gan_model.generated_data[:num_comparisons]))
  image_list.extend(array_ops.unstack(gan_model.real_data[:num_comparisons]))
  if display_diffs:
    generated_list = array_ops.unstack(
        gan_model.generated_data[:num_comparisons])
    real_list = array_ops.unstack(gan_model.real_data[:num_comparisons])
    diffs = [
        math_ops.abs(math_ops.to_float(generated) - math_ops.to_float(real)) for
        generated, real in zip(generated_list, real_list)]
    image_list.extend(diffs)

  # Reshape image and display.
  summary.image(
      'image_comparison',
      eval_utils.image_reshaper(image_list, num_cols=num_comparisons),
      max_outputs=1)


def add_gan_model_summaries(gan_model):
  """Adds typical GANModel summaries.

  Args:
    gan_model: A GANModel tuple.
  """
  with ops.name_scope('generator_variables'):
    for var in gan_model.generator_variables:
      summary.histogram(var.name, var)
  with ops.name_scope('discriminator_variables'):
    for var in gan_model.discriminator_variables:
      summary.histogram(var.name, var)


def add_regularization_loss_summaries(gan_model):
  """Adds summaries for regularization losses.

  Args:
    gan_model: A GANModel tuple.
  """
  if gan_model.generator_scope:
    summary.scalar(
        'generator_regularization_loss',
        loss_util.get_regularization_loss(gan_model.generator_scope.name))
  if gan_model.discriminator_scope:
    summary.scalar(
        'discriminator_regularization_loss',
        loss_util.get_regularization_loss(gan_model.discriminator_scope.name))
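Wired together, these are typically called once right after the GANModel is constructed. A minimal sketch; the `gan_model` value is hypothetical, e.g. the namedtuple built by `get_gan_model()` in the tests below:

import tensorflow as tf
from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries

gan_model = ...  # Hypothetical GANModel namedtuple with 4-D image data.
# grid_size**2 images are tiled, so grid_size must not exceed sqrt(batch_size).
summaries.add_gan_model_image_summaries(gan_model, grid_size=2)
summaries.add_image_comparison_summaries(
    gan_model, num_comparisons=2, display_diffs=True)
summaries.add_regularization_loss_summaries(gan_model)
merged = tf.summary.merge_all()  # Evaluate in a session; write via FileWriter.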
tensorflow/contrib/gan/python/eval/python/summaries_test.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFGAN summaries."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary


def generator_model(inputs):
  return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs


def discriminator_model(inputs, _):
  return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs


def get_gan_model():
  # TODO(joelshor): Find a better way of creating a variable scope.
  with variable_scope.variable_scope('generator') as gen_scope:
    pass
  with variable_scope.variable_scope('discriminator') as dis_scope:
    pass
  return namedtuples.GANModel(
      generator_inputs=array_ops.zeros([4, 32, 32, 3]),
      generated_data=array_ops.zeros([4, 32, 32, 3]),
      generator_variables=[variables.Variable(0), variables.Variable(1)],
      generator_scope=gen_scope,
      generator_fn=generator_model,
      real_data=array_ops.ones([4, 32, 32, 3]),
      discriminator_real_outputs=array_ops.ones([1, 2, 3]),
      discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
      discriminator_variables=[variables.Variable(0)],
      discriminator_scope=dis_scope,
      discriminator_fn=discriminator_model)


class SummariesTest(test.TestCase):

  def testAddGanModelImageSummaries(self):
    summaries.add_gan_model_image_summaries(get_gan_model(), grid_size=2)

    self.assertEquals(5, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
    with self.test_session(use_gpu=True):
      variables.global_variables_initializer().run()
      summary.merge_all().eval()

  def testAddGanModelSummaries(self):
    summaries.add_gan_model_summaries(get_gan_model())

    self.assertEquals(3, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
    with self.test_session(use_gpu=True):
      variables.global_variables_initializer().run()
      summary.merge_all().eval()

  def testAddRegularizationLossSummaries(self):
    summaries.add_regularization_loss_summaries(get_gan_model())

    self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
    with self.test_session(use_gpu=True):
      summary.merge_all().eval()

  # TODO(joelshor): Add correctness test.
  def testAddImageComparisonSummaries(self):
    summaries.add_image_comparison_summaries(
        get_gan_model(), display_diffs=True)

    self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
    with self.test_session(use_gpu=True):
      summary.merge_all().eval()


if __name__ == '__main__':
  test.main()
@@ -88,6 +88,7 @@ cuda_py_test(
    size = "medium",
    srcs = ["python/kernel_tests/image_ops_test.py"],
    additional_deps = [
        ":distort_image_py",
        ":image_py",
        ":single_image_random_dot_stereograms_py",
        "//third_party/py/numpy",
@@ -99,6 +100,80 @@ cuda_py_test(
    ],
)

tf_custom_op_library(
    name = "python/ops/_distort_image_ops.so",
    srcs = [
        "kernels/adjust_hsv_in_yiq_op.cc",
        "ops/distort_image_ops.cc",
    ],
    deps = [
        "@protobuf_archive//:protobuf",
    ],
)

tf_gen_op_libs(
    op_lib_names = ["distort_image_ops"],
)

tf_gen_op_wrapper_py(
    name = "distort_image_ops",
    deps = [":distort_image_ops_op_lib"],
)

cc_library(
    name = "distort_image_ops_cc",
    srcs = [
        "kernels/adjust_hsv_in_yiq_op.cc",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//third_party/eigen3",
    ],
    alwayslink = 1,
)

py_library(
    name = "distort_image_py",
    srcs = [
        "__init__.py",
        "python/ops/distort_image_ops.py",
    ],
    data = [":python/ops/_distort_image_ops.so"],
    srcs_version = "PY2AND3",
    deps = [
        ":distort_image_ops",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:image_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:random_ops",
    ],
)

cuda_py_test(
    name = "distort_image_ops_test",
    size = "medium",
    srcs = ["python/kernel_tests/distort_image_ops_test.py"],
    additional_deps = [
        ":distort_image_py",
        ":image_py",
        ":single_image_random_dot_stereograms_py",
        "//third_party/py/numpy",
        "//tensorflow/python:client",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:variables",
        "//tensorflow/core:protos_all_py",
    ],
)

tf_custom_op_library(
    name = "python/ops/_single_image_random_dot_stereograms.so",
    srcs = [
@@ -16,11 +16,14 @@

### API

-This module provides functions for image manipulation; currently, only
+This module provides functions for image manipulation; currently, chrominance
+transforms (including changing saturation and hue) in YIQ space and
projective transforms (including rotation) are supported.

@@angles_to_projective_transforms
@@compose_transforms
@@adjust_hsv_in_yiq
@@random_hsv_in_yiq
@@rotate
@@transform
@@bipartite_match
@@ -31,6 +34,9 @@ from __future__ import division
from __future__ import print_function

# pylint: disable=line-too-long
from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq
from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq

from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms
from tensorflow.contrib.image.python.ops.image_ops import compose_transforms
from tensorflow.contrib.image.python.ops.image_ops import rotate
@@ -39,5 +45,6 @@ from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms imp

from tensorflow.python.util.all_util import remove_undocumented

# pylint: enable=line-too-long

remove_undocumented(__name__)
tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc (new file, 172 lines)
@@ -0,0 +1,172 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <memory>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

class AdjustHsvInYiqOpBase : public OpKernel {
 protected:
  explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  struct ComputeOptions {
    const Tensor* input = nullptr;
    const Tensor* delta_h = nullptr;
    const Tensor* scale_s = nullptr;
    const Tensor* scale_v = nullptr;
    Tensor* output = nullptr;
    int64 channel_count = 0;
  };

  virtual void DoCompute(OpKernelContext* context,
                         const ComputeOptions& options) = 0;

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& delta_h = context->input(1);
    const Tensor& scale_s = context->input(2);
    const Tensor& scale_v = context->input(3);
    OP_REQUIRES(context, input.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()),
                errors::InvalidArgument("delta_h must be scalar: ",
                                        delta_h.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()),
                errors::InvalidArgument("scale_s must be scalar: ",
                                        scale_s.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()),
                errors::InvalidArgument("scale_v must be scalar: ",
                                        scale_v.shape().DebugString()));
    auto channels = input.dim_size(input.dims() - 1);
    OP_REQUIRES(
        context, channels == 3,
        errors::InvalidArgument("input must have 3 channels but instead has ",
                                channels, " channels."));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    if (input.NumElements() > 0) {
      const int64 channel_count = input.NumElements() / channels;
      ComputeOptions options;
      options.input = &input;
      options.delta_h = &delta_h;
      options.scale_s = &scale_s;
      options.scale_v = &scale_v;
      options.output = output;
      options.channel_count = channel_count;
      DoCompute(context, options);
    }
  }
};

template <class Device>
class AdjustHsvInYiqOp;

template <>
class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
 public:
  explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
      : AdjustHsvInYiqOpBase(context) {}

  void DoCompute(OpKernelContext* context,
                 const ComputeOptions& options) override {
    const Tensor* input = options.input;
    Tensor* output = options.output;
    const int64 channel_count = options.channel_count;
    static const int kChannelSize = 3;
    auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
    const float delta_h = options.delta_h->scalar<float>()();
    const float scale_s = options.scale_s->scalar<float>()();
    const float scale_v = options.scale_v->scalar<float>()();
    auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
    const int kCostPerChannel = 10;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
          kCostPerChannel,
          [channel_count, &input_data, &output_data, delta_h, scale_s, scale_v](
              int64 start_channel, int64 end_channel) {
            // Using the approximate linear transformation described in:
            // https://beesbuzz.biz/code/hsv_color_transforms.php
            /** Get the constants from sympy:
                from sympy import Matrix, symbols
                v, vsu, vsw = symbols('v vsu vsw')
                # Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ
                tyiq = Matrix([[0.299, 0.587, 0.114],
                               [0.596, -0.274, -0.322],
                               [0.211, -0.523, 0.312]])
                # Hue rotation matrix in YIQ space, where vsu = v * s * cos(h)
                # and vsw = v * s * sin(h).
                hue_proj = Matrix(3, 3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu])
                m = tyiq.inv() * hue_proj * tyiq
            **/
            // TODO(huangyp): directly compute the projection matrix from tyiq.
            static const float t[kChannelSize][kChannelSize][kChannelSize] = {
                {{.299, .701, .16862179492229},
                 {.587, -.587, .329804745287403},
                 {.114, -.114, -0.498426540209694}},
                {{.299, -.299, -.327963394172371},
                 {.587, .413, .0346106879248821},
                 {.114, -.114, .293352706247489}},
                {{.299, -.299, 1.24646136576682},
                 {.587, -.587, -1.04322888291964},
                 {.114, .886, -.203232482847173}}};
            float m[kChannelSize][kChannelSize] = {{0.}};
            float su = scale_s * std::cos(delta_h);
            float sw = scale_s * std::sin(delta_h);
            for (int q_index = 0; q_index < kChannelSize; q_index++) {
              for (int p_index = 0; p_index < kChannelSize; p_index++) {
                m[q_index][p_index] = scale_v * (t[q_index][p_index][0] +
                                                 t[q_index][p_index][1] * su +
                                                 t[q_index][p_index][2] * sw);
              }
            }
            // Apply the projection matrix m to each input RGB vector.
            const float* p = input_data.data() + start_channel * kChannelSize;
            float* q = output_data.data() + start_channel * kChannelSize;
            for (int i = start_channel; i < end_channel; i++) {
              for (int q_index = 0; q_index < kChannelSize; q_index++) {
                q[q_index] = 0;
                for (int p_index = 0; p_index < kChannelSize; p_index++) {
                  q[q_index] += m[q_index][p_index] * p[p_index];
                }
              }
              p += kChannelSize;
              q += kChannelSize;
            }
          });
  }
};

REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU),
                        AdjustHsvInYiqOp<CPUDevice>);

// TODO(huangyp): add the GPU kernel.
}  // namespace tensorflow
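The hard-coded table `t` above is that sympy product expanded numerically. A numpy sketch (illustrative, using only the constants already shown in the kernel comment) that rebuilds the same per-image matrix `m` directly, which is what the TODO suggests doing:

import numpy as np

# RGB -> YIQ projection, same constants as in the kernel comment.
tyiq = np.array([[0.299, 0.587, 0.114],
                 [0.596, -0.274, -0.322],
                 [0.211, -0.523, 0.312]])


def hsv_in_yiq_matrix(delta_h, scale_s, scale_v):
  """Rebuild the kernel's matrix m = tyiq^-1 * hue_proj * tyiq."""
  su = scale_s * np.cos(delta_h)
  sw = scale_s * np.sin(delta_h)
  hue_proj = scale_v * np.array([[1., 0., 0.],
                                 [0., su, -sw],
                                 [0., sw, su]])
  return np.linalg.inv(tyiq).dot(hue_proj).dot(tyiq)

# Each pixel is transformed as rgb_out = m.dot(rgb_in), which is exactly what
# the sharded loop above computes from the precomputed table t.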
tensorflow/contrib/image/ops/distort_image_ops.cc (new file, 60 lines)
@@ -0,0 +1,60 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::InferenceContext;

// --------------------------------------------------------------------------
REGISTER_OP("AdjustHsvInYiq")
    .Input("images: T")
    .Input("delta_h: float")
    .Input("scale_s: float")
    .Input("scale_v: float")
    .Output("output: T")
    .Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
    })
    .Doc(R"Doc(
Adjust the YIQ hue of one or more images.

`images` is a tensor of at least 3 dimensions. The last dimension is
interpreted as channels, and must be three.

We use the linear transformation described in:
beesbuzz.biz/code/hsv_color_transforms.php
The input image is considered in the RGB colorspace. Conceptually, the RGB
colors are first mapped into YIQ space, rotated around the Y channel by
delta_h in radians, the chrominance channels (I, Q) are multiplied by scale_s,
all channels (Y, I, Q) are multiplied by scale_v, and the result is mapped
back to the RGB colorspace. Each operation described above is a linear
transformation.

images: Images to adjust. At least 3-D.
delta_h: A float scalar that represents the hue rotation amount, in radians;
  it can be any float value.
scale_s: A float scalar that represents the factor to multiply the saturation
  by. scale_s needs to be non-negative.
scale_v: A float scalar that represents the factor to multiply the value by.
  scale_v needs to be non-negative.
output: The hsv-adjusted image or images. No clipping is done in this op;
  the client can clip the output using additional ops in their graph.
)Doc");

}  // namespace tensorflow
tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py (new file, 338 lines)
@@ -0,0 +1,338 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for python distort_image_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.image.python.ops import distort_image_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test


# TODO(huangyp): also measure the differences between AdjustHsvInYiq and
# AdjustHsv in core.
class AdjustHueInYiqTest(test_util.TensorFlowTestCase):

  def _adjust_hue_in_yiq_np(self, x_np, delta_h):
    """Rotate hue in YIQ space.

    Mathematically, we first convert the RGB color to YIQ space, rotate the
    hue by delta_h radians, and then convert back to RGB.

    Args:
      x_np: input x with last dimension = 3.
      delta_h: amount of hue rotation, in radians.

    Returns:
      Adjusted y with the same shape as x_np.
    """
    self.assertEqual(x_np.shape[-1], 3)
    x_v = x_np.reshape([-1, 3])
    u = np.cos(delta_h)
    w = np.sin(delta_h)
    # Projection matrix from RGB to YIQ. Numbers from wikipedia
    # https://en.wikipedia.org/wiki/YIQ
    tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.322],
                     [0.211, -0.523, 0.312]])
    y_v = np.dot(x_v, tyiq.T)
    # Hue rotation matrix in YIQ space.
    hue_rotation = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
    y_v = np.dot(y_v, hue_rotation.T)
    # Projecting back to RGB space.
    y_v = np.dot(y_v, np.linalg.inv(tyiq).T)
    return y_v.reshape(x_np.shape)

  def _adjust_hue_in_yiq_tf(self, x_np, delta_h):
    with self.test_session(use_gpu=True):
      x = constant_op.constant(x_np)
      y = distort_image_ops.adjust_hsv_in_yiq(x, delta_h, 1, 1)
      y_tf = y.eval()
    return y_tf

  def test_adjust_random_hue_in_yiq(self):
    x_shapes = [
        [2, 2, 3],
        [4, 2, 3],
        [2, 4, 3],
        [2, 5, 3],
        [1000, 1, 3],
    ]
    test_styles = [
        'all_random',
        'rg_same',
        'rb_same',
        'gb_same',
        'rgb_same',
    ]
    for x_shape in x_shapes:
      for test_style in test_styles:
        x_np = np.random.rand(*x_shape) * 255.
        delta_h = (np.random.rand() * 2.0 - 1.0) * np.pi
        if test_style == 'all_random':
          pass
        elif test_style == 'rg_same':
          x_np[..., 1] = x_np[..., 0]
        elif test_style == 'rb_same':
          x_np[..., 2] = x_np[..., 0]
        elif test_style == 'gb_same':
          x_np[..., 2] = x_np[..., 1]
        elif test_style == 'rgb_same':
          x_np[..., 1] = x_np[..., 0]
          x_np[..., 2] = x_np[..., 0]
        else:
          raise AssertionError('Invalid test style: %s' % (test_style))
        y_np = self._adjust_hue_in_yiq_np(x_np, delta_h)
        y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h)
        self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)

  def test_invalid_shapes(self):
    x_np = np.random.rand(2, 3) * 255.
    delta_h = np.random.rand() * 2.0 - 1.0
    with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
      self._adjust_hue_in_yiq_tf(x_np, delta_h)
    x_np = np.random.rand(4, 2, 4) * 255.
    delta_h = np.random.rand() * 2.0 - 1.0
    with self.assertRaisesOpError('input must have 3 channels but instead has '
                                  '4 channels'):
      self._adjust_hue_in_yiq_tf(x_np, delta_h)


class AdjustValueInYiqTest(test_util.TensorFlowTestCase):

  def _adjust_value_in_yiq_np(self, x_np, scale):
    return x_np * scale

  def _adjust_value_in_yiq_tf(self, x_np, scale):
    with self.test_session(use_gpu=True):
      x = constant_op.constant(x_np)
      y = distort_image_ops.adjust_hsv_in_yiq(x, 0, 1, scale)
      y_tf = y.eval()
    return y_tf

  def test_adjust_random_value_in_yiq(self):
    x_shapes = [
        [2, 2, 3],
        [4, 2, 3],
        [2, 4, 3],
        [2, 5, 3],
        [1000, 1, 3],
    ]
    test_styles = [
        'all_random',
        'rg_same',
        'rb_same',
        'gb_same',
        'rgb_same',
    ]
    for x_shape in x_shapes:
      for test_style in test_styles:
        x_np = np.random.rand(*x_shape) * 255.
        scale = np.random.rand() * 2.0 - 1.0
        if test_style == 'all_random':
          pass
        elif test_style == 'rg_same':
          x_np[..., 1] = x_np[..., 0]
        elif test_style == 'rb_same':
          x_np[..., 2] = x_np[..., 0]
        elif test_style == 'gb_same':
          x_np[..., 2] = x_np[..., 1]
        elif test_style == 'rgb_same':
          x_np[..., 1] = x_np[..., 0]
          x_np[..., 2] = x_np[..., 0]
        else:
          raise AssertionError('Invalid test style: %s' % (test_style))
        y_np = self._adjust_value_in_yiq_np(x_np, scale)
        y_tf = self._adjust_value_in_yiq_tf(x_np, scale)
        self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5)

  def test_invalid_shapes(self):
    x_np = np.random.rand(2, 3) * 255.
    scale = np.random.rand() * 2.0 - 1.0
    with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
      self._adjust_value_in_yiq_tf(x_np, scale)
    x_np = np.random.rand(4, 2, 4) * 255.
    scale = np.random.rand() * 2.0 - 1.0
    with self.assertRaisesOpError('input must have 3 channels but instead has '
                                  '4 channels'):
      self._adjust_value_in_yiq_tf(x_np, scale)


class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase):

  def _adjust_saturation_in_yiq_tf(self, x_np, scale):
    with self.test_session(use_gpu=True):
      x = constant_op.constant(x_np)
      y = distort_image_ops.adjust_hsv_in_yiq(x, 0, scale, 1)
      y_tf = y.eval()
    return y_tf

  def _adjust_saturation_in_yiq_np(self, x_np, scale):
    """Adjust saturation using linear interpolation."""
    rgb_weights = np.array([0.299, 0.587, 0.114])
    gray = np.sum(x_np * rgb_weights, axis=-1, keepdims=True)
    y_v = x_np * scale + gray * (1 - scale)
    return y_v

  def test_adjust_random_saturation_in_yiq(self):
    x_shapes = [
        [2, 2, 3],
        [4, 2, 3],
        [2, 4, 3],
        [2, 5, 3],
        [1000, 1, 3],
    ]
    test_styles = [
        'all_random',
        'rg_same',
        'rb_same',
        'gb_same',
        'rgb_same',
    ]
    with self.test_session():
      for x_shape in x_shapes:
        for test_style in test_styles:
          x_np = np.random.rand(*x_shape) * 255.
          scale = np.random.rand() * 2.0 - 1.0
          if test_style == 'all_random':
            pass
          elif test_style == 'rg_same':
            x_np[..., 1] = x_np[..., 0]
          elif test_style == 'rb_same':
            x_np[..., 2] = x_np[..., 0]
          elif test_style == 'gb_same':
            x_np[..., 2] = x_np[..., 1]
          elif test_style == 'rgb_same':
            x_np[..., 1] = x_np[..., 0]
            x_np[..., 2] = x_np[..., 0]
          else:
            raise AssertionError('Invalid test style: %s' % (test_style))
          y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale)
          y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale)
          self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5)

  def test_invalid_shapes(self):
    x_np = np.random.rand(2, 3) * 255.
    scale = np.random.rand() * 2.0 - 1.0
    with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
      self._adjust_saturation_in_yiq_tf(x_np, scale)
    x_np = np.random.rand(4, 2, 4) * 255.
    scale = np.random.rand() * 2.0 - 1.0
    with self.assertRaisesOpError('input must have 3 channels but instead has '
                                  '4 channels'):
      self._adjust_saturation_in_yiq_tf(x_np, scale)


class AdjustHueInYiqBenchmark(test.Benchmark):

  def _benchmark_adjust_hue_in_yiq(self, device, cpu_count):
    image_shape = [299, 299, 3]
    warmup_rounds = 100
    benchmark_rounds = 1000
    config = config_pb2.ConfigProto()
    if cpu_count is not None:
      config.inter_op_parallelism_threads = 1
      config.intra_op_parallelism_threads = cpu_count
    with session.Session('', graph=ops.Graph(), config=config) as sess:
      with ops.device(device):
        inputs = variables.Variable(
            random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
            trainable=False,
            dtype=dtypes.float32)
        delta = constant_op.constant(0.1, dtype=dtypes.float32)
        outputs = distort_image_ops.adjust_hsv_in_yiq(inputs, delta, 1, 1)
        run_op = control_flow_ops.group(outputs)
        sess.run(variables.global_variables_initializer())
        for i in xrange(warmup_rounds + benchmark_rounds):
          if i == warmup_rounds:
            start = time.time()
          sess.run(run_op)
    end = time.time()
    step_time = (end - start) / benchmark_rounds
    tag = device + '_%s' % (cpu_count if cpu_count is not None else 'all')
    print('benchmarkadjust_hue_in_yiq_299_299_3_%s step_time: %.2f us' %
          (tag, step_time * 1e6))
    self.report_benchmark(
        name='benchmarkadjust_hue_in_yiq_299_299_3_%s' % (tag),
        iters=benchmark_rounds,
        wall_time=step_time)

  def benchmark_adjust_hue_in_yiqCpu1(self):
    self._benchmark_adjust_hue_in_yiq('/cpu:0', 1)

  def benchmark_adjust_hue_in_yiqCpuAll(self):
    self._benchmark_adjust_hue_in_yiq('/cpu:0', None)


class AdjustSaturationInYiqBenchmark(test.Benchmark):

  def _benchmark_adjust_saturation_in_yiq(self, device, cpu_count):
    image_shape = [299, 299, 3]
    warmup_rounds = 100
    benchmark_rounds = 1000
    config = config_pb2.ConfigProto()
    if cpu_count is not None:
      config.inter_op_parallelism_threads = 1
      config.intra_op_parallelism_threads = cpu_count
    with session.Session('', graph=ops.Graph(), config=config) as sess:
      with ops.device(device):
        inputs = variables.Variable(
            random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
            trainable=False,
            dtype=dtypes.float32)
        scale = constant_op.constant(0.1, dtype=dtypes.float32)
        outputs = distort_image_ops.adjust_hsv_in_yiq(inputs, 0, scale, 1)
        run_op = control_flow_ops.group(outputs)
        sess.run(variables.global_variables_initializer())
        for _ in xrange(warmup_rounds):
          sess.run(run_op)
        start = time.time()
        for _ in xrange(benchmark_rounds):
          sess.run(run_op)
    end = time.time()
    step_time = (end - start) / benchmark_rounds
    tag = '%s' % (cpu_count) if cpu_count is not None else '_all'
    print('benchmarkAdjustSaturationInYiq_299_299_3_cpu%s step_time: %.2f us' %
          (tag, step_time * 1e6))
    self.report_benchmark(
        name='benchmarkAdjustSaturationInYiq_299_299_3_cpu%s' % (tag),
        iters=benchmark_rounds,
        wall_time=step_time)

  def benchmark_adjust_saturation_in_yiq_cpu1(self):
    self._benchmark_adjust_saturation_in_yiq('/cpu:0', 1)

  def benchmark_adjust_saturation_in_yiq_cpu_all(self):
    self._benchmark_adjust_saturation_in_yiq('/cpu:0', None)


if __name__ == '__main__':
  googletest.main()
tensorflow/contrib/image/python/ops/distort_image_ops.py (new file, 138 lines)
@@ -0,0 +1,138 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python layer for distort_image_ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
_distort_image_ops = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile('_distort_image_ops.so'))
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def random_hsv_in_yiq(image,
|
||||
max_delta_hue=0,
|
||||
lower_saturation=1,
|
||||
upper_saturation=1,
|
||||
lower_value=1,
|
||||
upper_value=1,
|
||||
seed=None):
|
||||
"""Adjust hue, saturation, value of an RGB image randomly in YIQ color space.
|
||||
|
||||
  Equivalent to `adjust_hsv_in_yiq()` but uses a `delta_hue` randomly
  picked in the interval `[-max_delta_hue, max_delta_hue]`, a `scale_saturation`
  randomly picked in the interval `[lower_saturation, upper_saturation]`, and
  a `scale_value` randomly picked in the interval
  `[lower_value, upper_value]`.

  Args:
    image: RGB image or images. Size of the last dimension must be 3.
    max_delta_hue: float. Maximum value for the random delta_hue. Passing 0
      disables adjusting hue.
    lower_saturation: float. Lower bound for the random scale_saturation.
    upper_saturation: float. Upper bound for the random scale_saturation.
    lower_value: float. Lower bound for the random scale_value.
    upper_value: float. Upper bound for the random scale_value.
    seed: An operation-specific seed. It will be used in conjunction
      with the graph-level seed to determine the real seeds that will be
      used in this operation. Please see the documentation of
      set_random_seed for its interaction with the graph-level random seed.

  Returns:
    3-D float tensor of shape `[height, width, channels]`.

  Raises:
    ValueError: if `max_delta_hue`, `lower_saturation`, `upper_saturation`,
      `lower_value`, or `upper_value` is invalid.
  """
  if max_delta_hue < 0:
    raise ValueError('max_delta_hue must be non-negative.')

  if lower_saturation < 0:
    raise ValueError('lower_saturation must be non-negative.')

  if lower_value < 0:
    raise ValueError('lower_value must be non-negative.')

  if lower_saturation > upper_saturation:
    raise ValueError('lower_saturation must be <= upper_saturation.')

  if lower_value > upper_value:
    raise ValueError('lower_value must be <= upper_value.')

  if max_delta_hue == 0:
    delta_hue = 0
  else:
    delta_hue = random_ops.random_uniform(
        [], -max_delta_hue, max_delta_hue, seed=seed)
  if lower_saturation == upper_saturation:
    scale_saturation = lower_saturation
  else:
    scale_saturation = random_ops.random_uniform(
        [], lower_saturation, upper_saturation, seed=seed)
  if lower_value == upper_value:
    scale_value = lower_value
  else:
    scale_value = random_ops.random_uniform(
        [], lower_value, upper_value, seed=seed)
  return adjust_hsv_in_yiq(image, delta_hue, scale_saturation, scale_value)


def adjust_hsv_in_yiq(image,
                      delta_hue=0,
                      scale_saturation=1,
                      scale_value=1,
                      name=None):
  """Adjust hue, saturation, value of an RGB image in YIQ color space.

  This is a convenience method that converts an RGB image to float
  representation, converts it to YIQ, rotates the color around the Y channel
  by delta_hue in radians, scales the chrominance channels (I, Q) by
  scale_saturation, scales all channels (Y, I, Q) by scale_value,
  converts back to RGB, and then back to the original data type.

  `image` is an RGB image. The image hue is adjusted by converting the
  image to YIQ, rotating around the luminance channel (Y) by
  `delta_hue` in radians, multiplying the chrominance channels (I, Q) by
  `scale_saturation`, and multiplying all channels (Y, I, Q) by
  `scale_value`. The image is then converted back to RGB.

  Args:
    image: RGB image or images. Size of the last dimension must be 3.
    delta_hue: float, the hue rotation amount, in radians.
    scale_saturation: float, factor to multiply the saturation by.
    scale_value: float, factor to multiply the value by.
    name: A name for this operation (optional).

  Returns:
    Adjusted image(s), same shape and DType as `image`.
  """
  with ops.name_scope(name, 'adjust_hsv_in_yiq', [image]) as name:
    image = ops.convert_to_tensor(image, name='image')
    # Remember the original dtype so we can convert back if needed.
    orig_dtype = image.dtype
    flt_image = image_ops.convert_image_dtype(image, dtypes.float32)

    rgb_altered = _distort_image_ops.adjust_hsv_in_yiq(
        flt_image, delta_hue, scale_saturation, scale_value)

    return image_ops.convert_image_dtype(rgb_altered, orig_dtype)
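For reference, a minimal usage sketch of the two functions above. The
tf.contrib.image module path is an assumption about where these ops are
exported; the rest is standard TF 1.x graph-mode API.

    import tensorflow as tf

    image = tf.random_uniform([224, 224, 3], dtype=tf.float32)  # dummy RGB image

    # Deterministic adjustment: rotate hue by 0.3 rad, rescale saturation/value.
    adjusted = tf.contrib.image.adjust_hsv_in_yiq(
        image, delta_hue=0.3, scale_saturation=0.8, scale_value=1.2)

    # Random adjustment drawn from the documented intervals.
    randomized = tf.contrib.image.random_hsv_in_yiq(
        image, max_delta_hue=0.5,
        lower_saturation=0.5, upper_saturation=1.5,
        lower_value=0.8, upper_value=1.2, seed=42)

    with tf.Session() as sess:
      adjusted_np, randomized_np = sess.run([adjusted, randomized])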
@@ -1678,9 +1678,14 @@ class _MultiHead(Head):
       ModelFnOps that merges all heads for TRAIN.
     """
     losses = []
+    metrics = {}
     additional_train_ops = []
     for m in all_model_fn_ops:
       losses.append(m.loss)
+      if m.eval_metric_ops is not None:
+        for k, v in six.iteritems(m.eval_metric_ops):
+          # metrics["%s/%s" % (k, head_name)] = v
+          metrics[k] = v
       additional_train_ops.append(m.train_op)
     loss = self._loss_merger(losses)

@@ -1689,7 +1694,8 @@ class _MultiHead(Head):
     return model_fn.ModelFnOps(
         mode=model_fn.ModeKeys.TRAIN,
         loss=loss,
-        train_op=train_op)
+        train_op=train_op,
+        eval_metric_ops=metrics)

   def _merge_infer(self, all_model_fn_ops):
     """Merges list of ModelFnOps for inference.
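To make the new TRAIN-mode behavior concrete, here is a minimal sketch of the
merge (HeadOps is a hypothetical stand-in for ModelFnOps; metric keys are
assumed to be head-scoped already, which is what the commented-out line above
hints at):

    import collections
    import tensorflow as tf

    HeadOps = collections.namedtuple('HeadOps', ['loss', 'eval_metric_ops'])

    head1 = HeadOps(tf.constant(1.0),
                    {'head1/auc': (tf.constant(0.7), tf.no_op())})
    head2 = HeadOps(tf.constant(2.0),
                    {'head2/auc': (tf.constant(0.9), tf.no_op())})

    losses, metrics = [], {}
    for m in (head1, head2):
      losses.append(m.loss)
      if m.eval_metric_ops is not None:
        metrics.update(m.eval_metric_ops)

    merged_loss = tf.add_n(losses)  # one possible loss merger
    # The TRAIN ModelFnOps now carries `metrics` instead of dropping them.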
@@ -1703,7 +1703,7 @@ class MultiHeadTest(test.TestCase):
     self.assertIsNone(model_fn_ops.predictions)
     self.assertIsNotNone(model_fn_ops.loss)
     self.assertIsNotNone(model_fn_ops.train_op)
-    self.assertFalse(model_fn_ops.eval_metric_ops)
+    self.assertTrue(model_fn_ops.eval_metric_ops)
     self.assertIsNone(model_fn_ops.output_alternatives)

     with session.Session() as sess:
@@ -1728,7 +1728,7 @@ class MultiHeadTest(test.TestCase):
     self.assertIsNone(model_fn_ops.predictions)
     self.assertIsNotNone(model_fn_ops.loss)
     self.assertIsNotNone(model_fn_ops.train_op)
-    self.assertFalse(model_fn_ops.eval_metric_ops)
+    self.assertTrue(model_fn_ops.eval_metric_ops)
     self.assertIsNone(model_fn_ops.output_alternatives)

     with session.Session() as sess:
@@ -1755,7 +1755,7 @@ class MultiHeadTest(test.TestCase):
     self.assertIsNone(model_fn_ops.predictions)
     self.assertIsNotNone(model_fn_ops.loss)
     self.assertIsNotNone(model_fn_ops.train_op)
-    self.assertFalse(model_fn_ops.eval_metric_ops)
+    self.assertTrue(model_fn_ops.eval_metric_ops)
     self.assertIsNone(model_fn_ops.output_alternatives)

     with session.Session() as sess:
@@ -113,7 +113,7 @@ struct GatherTree<CPUDevice, int32> {
       const int32 batch = i / beam_width;
       const int32 beam = i % beam_width;
       int32 seq_len_b = sequence_length(batch, beam);
-      if (seq_len_b == 0) {
+      if (seq_len_b <= 0) {
         continue;
       }
       beams(seq_len_b - 1, batch, beam) =

@@ -33,7 +33,10 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
   CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
     const int32 batch = i / beam_width;
     const int32 beam = i % beam_width;
+
+    const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
+    if (seq_len_b <= 0) continue;

 #define GET_IX(time_ix, beam_ix) \
   (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
     const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
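A short sketch of why the guard tightens from == 0 to <= 0: with a
non-positive seq_len_b, the first write below the guard would index row
seq_len_b - 1, i.e. -1 or lower (shapes here are hypothetical; NumPy wraps
negative indices where the C++ kernel would read out of bounds):

    import numpy as np

    beams = np.zeros((5, 2, 3), dtype=np.int32)  # [max_time, batch, beam]
    seq_len_b = 0
    if seq_len_b <= 0:
      pass  # skip: beams[seq_len_b - 1] would silently alias the last row
    else:
      beams[seq_len_b - 1, 0, 0] = 42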
@@ -155,7 +155,7 @@ void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
     // retrieving the sizes from the wrapped allocator removes the
     // executor's reference to it, so allocator_pair.second must not
     // be dereferenced again after this statement
-    auto sizes = allocator_pair.second->GetSizesAndUnRef();
+    const auto sizes = allocator_pair.second->GetSizesAndUnRef();
     memory->set_allocator_name(allocator_pair.first->Name());
     memory->set_total_bytes(std::get<0>(sizes));
     memory->set_peak_bytes(std::get<1>(sizes));
@@ -1373,7 +1373,7 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,

     for (const Edge* out_edge : curr_node->out_edges()) {
       Node* out = out_edge->dst();
-      int out_id = out->id();
+      const int out_id = out->id();

       // Add to ready queue if not visited.
       bool is_visited = visited[out_id];
@@ -1417,7 +1417,8 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {

   // Ask the device to fill in the device context map.
   Device* device = impl_->params_.device;
-  Status fill_status = device->FillContextMap(graph, &device_context_map_);
+  const Status fill_status =
+      device->FillContextMap(graph, &device_context_map_);
   if (!fill_status.ok()) {
     done(fill_status);
     return;
@@ -1525,7 +1526,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
     inline_ready.pop_front();
     const Node* node = tagged_node.node;
     FrameState* input_frame = tagged_node.input_frame;
-    int64 input_iter = tagged_node.input_iter;
+    const int64 input_iter = tagged_node.input_iter;
     const int id = node->id();
     const NodeItem& item = *gview.node(id);

@@ -1637,7 +1638,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
         device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
                                              accessed);
       }
-      bool completed =
+      const bool completed =
           NodeDone(s, state->item->node, ready, stats, nullptr);
       delete state;
       if (completed) Finish();
@@ -1803,7 +1804,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
   }

   for (int i = 0; i < item.num_outputs; ++i) {
-    TensorValue val = ctx->release_output(i);
+    const TensorValue val = ctx->release_output(i);
     if (*ctx->is_output_dead() || val.tensor == nullptr) {
       // Unless it's a Switch or a Recv, the node must produce a
       // tensor value at i-th output.
@@ -1893,7 +1894,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
                                      TaggedNodeSeq* ready) {
   const Node* node = tagged_node.node;
   FrameState* input_frame = tagged_node.input_frame;
-  int64 input_iter = tagged_node.input_iter;
+  const int64 input_iter = tagged_node.input_iter;
   const bool is_dead = tagged_node.is_dead;

   // Propagates outputs along out edges, and puts newly ready nodes
@@ -1913,7 +1914,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
                       &impl_->gview_, input_iter, ready);
   } else if (item->is_enter) {
     bool is_constant;
-    Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
+    const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
     DCHECK(s.ok()) << s;
     FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
     output_iter = 0;
@@ -1983,7 +1984,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
   // completion of this node makes its frame completed.
   if (is_frame_done) {
     FrameState* parent_frame = input_frame->parent_frame;
-    int64 parent_iter = input_frame->parent_iter;
+    const int64 parent_iter = input_frame->parent_iter;
     DeleteFrame(input_frame, ready);
     if (parent_frame != nullptr) {
       // The completion of frame may cause completions in its parent frame.
@@ -2026,7 +2027,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
   }

   bool completed = false;
-  size_t ready_size = ready.size();
+  const size_t ready_size = ready.size();
   if (ready_size == 0 || !s.ok()) {
     completed = (num_outstanding_ops_.fetch_sub(1) == 1);
   } else if (ready_size > 1) {
@@ -2166,7 +2167,7 @@ void ExecutorState::DumpIterationState(const FrameState* frame,
   const std::vector<const Node*>* nodes = frame->nodes;
   // Dump any waiting nodes that are holding on to tensors.
   for (const Node* node : *nodes) {
-    int node_id = node->id();
+    const int node_id = node->id();
     PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
     if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
         iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
@@ -2175,14 +2176,14 @@ void ExecutorState::DumpIterationState(const FrameState* frame,
   }
   // Then the active nodes.
   for (const Node* node : *nodes) {
-    int node_id = node->id();
+    const int node_id = node->id();
     PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
     if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
       DumpActiveNodeState(node_id, iteration->input_tensors);
     }
   }
   // Show all input tensors in use.
-  int total_input_tensors = frame->total_input_tensors;
+  const int total_input_tensors = frame->total_input_tensors;
   size_t total_bytes = 0;
   for (int i = 0; i < total_input_tensors; ++i) {
     const Entry& input = iteration->input_tensors[i];
@@ -2291,7 +2292,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
 void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
   // First, propagate dead_exits (if any) to the parent frame.
   FrameState* parent_frame = frame->parent_frame;
-  int64 parent_iter = frame->parent_iter;
+  const int64 parent_iter = frame->parent_iter;
   if (parent_frame != nullptr) {
     mutex_lock paranet_frame_lock(parent_frame->mu);
     // Propagate all the dead exits to the parent frame.
@@ -2300,7 +2301,8 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
     for (const Edge* e : node->out_edges()) {
       const Node* dst_node = e->dst();

-      auto dst_pending_id = impl_->gview_.node(dst_node->id())->pending_id;
+      const auto dst_pending_id =
+          impl_->gview_.node(dst_node->id())->pending_id;

       // TODO(yuanbyu): We don't need this if we require the subgraph
       // given to an executor not to contain a sink node.
@@ -2358,7 +2360,7 @@ void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
   }
   if (is_frame_done) {
     FrameState* parent_frame = frame->parent_frame;
-    int64 parent_iter = frame->parent_iter;
+    const int64 parent_iter = frame->parent_iter;
     DeleteFrame(frame, ready);
     if (parent_frame != nullptr) {
       // The completion of frame may cause completions in its parent frame.
@@ -2433,7 +2435,7 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
       }
     }
   } else {
-    bool increment_dead =
+    const bool increment_dead =
         (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
     int pending, dead;
     iter_state->adjust_for_activation(dst_pending_id, increment_dead,
@@ -2497,7 +2499,7 @@ void ExecutorState::FrameState::AddLoopInv(const NodeItem* item,
   inv_values.push_back({item->node, entry});

   // Make this value available to all iterations.
-  bool is_dead = !entry.has_value;
+  const bool is_dead = !entry.has_value;
   for (int i = 0; i <= iteration_count; ++i) {
     EntryVector outputs{entry};
     ActivateNodes(item, is_dead, i, &outputs, ready);
@@ -2522,7 +2524,7 @@ bool ExecutorState::FrameState::IsIterationDone(int64 iter) {
 void ExecutorState::FrameState::IncrementIteration(const GraphView* gview,
                                                    TaggedNodeSeq* ready) {
   iteration_count++;
-  int64 next_iter = iteration_count;
+  const int64 next_iter = iteration_count;

   // Initialize the next iteration.
   IterationState* iter_state =
@@ -2567,7 +2569,7 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
 Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
                         Executor** executor) {
   ExecutorImpl* impl = new ExecutorImpl(params, graph);
-  Status s = impl->Initialize();
+  const Status s = impl->Initialize();
   if (s.ok()) {
     *executor = impl;
   } else {
@@ -2579,7 +2581,7 @@ Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                              const NodeDef& ndef, int graph_def_version,
                              OpKernel** kernel) {
-  auto device_type = DeviceType(device->attributes().device_type());
+  const auto device_type = DeviceType(device->attributes().device_type());
   auto allocator = device->GetAllocator(AllocatorAttributes());
   return CreateOpKernel(device_type, device, allocator, flib, ndef,
                         graph_def_version, kernel);
@@ -665,7 +665,9 @@ tf_kernel_library(
 tf_kernel_library(
     name = "matrix_band_part_op",
     prefix = "matrix_band_part_op",
-    deps = ARRAY_DEPS,
+    deps = if_cuda([
+        ":cuda_solvers",
+    ]) + ARRAY_DEPS,
 )

 tf_kernel_library(
@@ -1332,7 +1334,7 @@ tf_kernel_library(
         "transpose_functor_gpu.cu.cc",
         "transpose_functor.h",
     ],
-    visibility = ["//visibility:private"],
+    visibility = [":friends"],
     deps = [
         ":ops_util",
         "//tensorflow/core:framework",
@@ -76,18 +76,19 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
 typedef Eigen::GpuDevice GPUDevice;

 namespace functor {
-#define DECLARE_GPU_SPEC(T) \
-  template <> \
-  void MatrixBandPart<GPUDevice, T>::Compute( \
-      const GPUDevice& d, Eigen::DenseIndex num_lower, \
-      Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
-      typename TTypes<T, 3>::Tensor output); \
-  extern template struct MatrixBandPart<GPUDevice, T>;
+#define DECLARE_GPU_SPEC(T) \
+  template <> \
+  struct MatrixBandPartFunctor<GPUDevice, T> { \
+    void operator()(OpKernelContext* context, const GPUDevice& device, \
+                    int num_lower_diags, int num_upper_diags, bool transpose, \
+                    typename TTypes<T, 3>::ConstTensor input, \
+                    typename TTypes<T, 3>::Tensor output); \
+  }; \
+  extern template struct MatrixBandPartFunctor<GPUDevice, T>;

 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 TF_CALL_complex64(DECLARE_GPU_SPEC);
 TF_CALL_complex128(DECLARE_GPU_SPEC);

 }  // namespace functor

 template <class Scalar>
@@ -131,9 +132,9 @@ class CholeskyOpGpu : public AsyncOpKernel {
     // before we launch each of the Cholesky factorization kernels in parallel.
     auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
     auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
-    functor::MatrixBandPart<GPUDevice, Scalar>::Compute(
-        context->eigen_device<GPUDevice>(), n, 0, input_reshaped,
-        output_reshaped);
+    functor::MatrixBandPartFunctor<GPUDevice, Scalar> fn;
+    fn(context, context->eigen_device<GPUDevice>(), n, 0, false /* transpose */,
+       input_reshaped, output_reshaped);

     // Launch a Cholesky kernel for each matrix in the batch.
     const int64 batch_size = input_reshaped.dimension(0);
@@ -1024,7 +1024,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
   assert(__popc(kWidth) == 1);
   int sub_warp = cub::LaneId() / kWidth;
   int zeros = sub_warp * kWidth;
-  unsigned mask = ((1U << kWidth) - 1) << zeros;
+  unsigned mask = ((1UL << kWidth) - 1) << zeros;
   for (int delta = kWidth / 2; delta > 0; delta /= 2) {
     val += CudaShuffleXor(mask, val, delta);
   }
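The 1U -> 1UL change matters when kWidth == 32: (1U << 32) shifts a 32-bit
value by its full width, which is undefined behavior in C++, while the wider
1UL keeps the shift well defined (on the LP64 targets nvcc compiles for)
before the result is truncated back to the 32-bit lane mask. A quick sketch of
the intended mask values (Python integers are arbitrary precision, so the
shift is always safe here):

    for k_width in (2, 4, 8, 16, 32):
      full = (1 << k_width) - 1  # all-ones pattern for one sub-warp
      for sub_warp in range(32 // k_width):
        mask = full << (sub_warp * k_width)
        assert mask <= 0xFFFFFFFF  # fits the 32-bit lane mask
      print(k_width, hex(full))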
@@ -23,6 +23,7 @@ limitations under the License.

 #include "tensorflow/core/kernels/matrix_band_part_op.h"

+#include <algorithm>
 #include <memory>
 #include <vector>
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -32,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"

@@ -48,18 +50,6 @@ class MatrixBandPartOp : public OpKernel {

   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
-    const Tensor& num_lower_in = context->input(1);
-    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
-                errors::InvalidArgument("num_lower must be scalar, got shape ",
-                                        num_lower_in.shape().DebugString()));
-    const int64 num_lower = num_lower_in.scalar<int64>()();
-
-    const Tensor& num_upper_in = context->input(2);
-    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
-                errors::InvalidArgument("num_upper must be scalar, got shape ",
-                                        num_upper_in.shape().DebugString()));
-    const int64 num_upper = num_upper_in.scalar<int64>()();
-
     const TensorShape& input_shape = input.shape();
     // Preliminary validation of sizes.
     OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
@@ -67,12 +57,43 @@ class MatrixBandPartOp : public OpKernel {
                     "input must be at least 2-dim, received shape: ",
                     input.shape().DebugString()));
     auto input_reshaped = input.flat_inner_dims<T, 3>();
+
+    const Tensor& num_lower_in = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
+                errors::InvalidArgument("num_lower must be scalar, got shape ",
+                                        num_lower_in.shape().DebugString()));
+    const int64 num_lower = num_lower_in.scalar<int64>()();
+    OP_REQUIRES(
+        context, num_lower <= input_reshaped.dimension(1),
+        errors::InvalidArgument(
+            "num_lower must be negative or less or equal to number of rows (",
+            input_reshaped.dimension(1), ") got: ", num_lower));
+
+    const Tensor& num_upper_in = context->input(2);
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
+                errors::InvalidArgument("num_upper must be scalar, got shape ",
+                                        num_upper_in.shape().DebugString()));
+    const int64 num_upper = num_upper_in.scalar<int64>()();
+    OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
+                errors::InvalidArgument("num_upper must be negative or less or "
+                                        "equal to number of columns (",
+                                        input_reshaped.dimension(2),
+                                        ") got: ", num_upper));
+
+    if ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
+        (num_upper < 0 || num_upper == input_reshaped.dimension(2))) {
+      // This is a no-op.
+      context->set_output(0, input);
+      return;
+    }
+
     Tensor* output = nullptr;
-    OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &output));
+    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+                                {0}, 0, input_shape, &output));
     auto output_reshaped = output->flat_inner_dims<T, 3>();
-    functor::MatrixBandPart<Device, T>::Compute(
-        context->eigen_device<Device>(), num_lower, num_upper, input_reshaped,
-        output_reshaped);
+    functor::MatrixBandPartFunctor<Device, T> fn;
+    fn(context, context->eigen_device<Device>(), num_lower, num_upper,
+       false /* transpose */, input_reshaped, output_reshaped);
   }

  private:
@@ -98,54 +119,118 @@ TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART);

 // Implementation of the functor specialization for CPU.
 namespace functor {
-template <typename T>
-struct MatrixBandPart<CPUDevice, T> {
-  static void Compute(const CPUDevice& d, int64 num_lower, int64 num_upper,
-                      typename TTypes<T, 3>::ConstTensor input,
-                      typename TTypes<T, 3>::Tensor output) {
-    if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
-        (num_upper < 0 || num_upper >= input.dimension(2))) {
-      output.device(d) = input;
-    } else {
-      output.device(d) = output.constant(T());
-      for (int64 r = 0; r < output.dimension(0); ++r) {
-        for (int64 i = 0; i < output.dimension(1); ++i) {
-          const int64 band_start =
-              num_lower < 0 ? 0 : std::max(0ll, i - num_lower);
-          const int64 band_end =
-              num_upper < 0 ? output.dimension(2)
-                            : std::min(static_cast<int64>(output.dimension(2)),
-                                       i + num_upper + 1);
-          if (band_start < band_end) {
-            const Eigen::DSizes<Eigen::DenseIndex, 3> indices(r, i, band_start);
-            const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
-                1, 1, band_end - band_start);
-            output.slice(indices, sizes) = input.slice(indices, sizes);
-
+// CPU implementation of BandPartFunctor.
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Scalar>
+struct MatrixBandPartFunctor<CPUDevice, Scalar> {
+  void operator()(OpKernelContext* context, const CPUDevice& device,
+                  int num_lower_diags, int num_upper_diags, bool transpose,
+                  typename TTypes<Scalar, 3>::ConstTensor input,
+                  typename TTypes<Scalar, 3>::Tensor output) {
+    const int64 b = input.dimension(0);
+    const int64 m = input.dimension(1);
+    const int64 n = input.dimension(2);
+    auto thread_pool =
+        context->device()->tensorflow_cpu_worker_threads()->workers;
+    const int64 total_rows = b * m;
+    const int64 row_cost = 10 * n;
+    const bool in_place = input.data() == output.data();
+    CHECK(!(transpose && in_place));
+    if (!transpose) {
+      auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+        if (!in_place) {
+          std::fill(output.data() + begin * n, output.data() + end * n,
+                    Scalar());
+        }
+        const int64 batch_begin = begin / m;
+        const int64 batch_end = (end + m - 1) / m;
+        for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+          const int64 row_begin = begin > batch * m ? begin % m : 0;
+          const int64 row_end = end < (batch + 1) * m ? end % m : m;
+          for (int64 row = row_begin; row < row_end; ++row) {
+            const int64 band_start =
+                num_lower_diags < 0
+                    ? 0
+                    : std::min(n, std::max(0ll, row - num_lower_diags));
+            const int64 band_end = num_upper_diags < 0
+                                       ? n
+                                       : std::min(static_cast<int64>(n),
+                                                  row + num_upper_diags + 1);
+            if (in_place) {
+              if (band_start > 0) {
+                std::fill(&output(batch, row, 0),
+                          &output(batch, row, band_start), Scalar());
+              }
+              if (band_end < n) {
+                std::fill(&output(batch, row, band_end), &output(batch, row, n),
+                          Scalar());
+              }
+            } else {
+              if (band_start < band_end) {
+                const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
+                                                                  band_start);
+                const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
+                    1, 1, band_end - band_start);
+                output.slice(indices, sizes) = input.slice(indices, sizes);
+              }
+            }
+          }
+        }
+      };
+      thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
+    } else {
+      output.device(device) = output.constant(Scalar());
+      auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+        const int64 batch_begin = begin / m;
+        const int64 batch_end = (end + m - 1) / m;
+        for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+          const int64 row_begin = begin > batch * m ? begin % m : 0;
+          const int64 row_end = end < (batch + 1) * m ? end % m : m;
+          for (int64 row = row_begin; row < row_end; ++row) {
+            const int64 band_start =
+                num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags);
+            const int64 band_end = num_upper_diags < 0
+                                       ? n
+                                       : std::min(static_cast<int64>(n),
                                                  row + num_upper_diags + 1);
+            for (int64 col = band_start; col < band_end; ++col) {
+              output(batch, col, row) = input(batch, row, col);
+            }
+          }
+        }
+      };
+      thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
+    }
+  }
+};
+
+#define DEFINE_CPU_SPEC(T) template struct MatrixBandPartFunctor<CPUDevice, T>;
+TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
+#undef DEFINE_CPU_SPEC

 }  // namespace functor

 #if GOOGLE_CUDA

 // Forward declarations of the functor specializations for GPU.
 namespace functor {
-#define DECLARE_GPU_SPEC(T) \
-  template <> \
-  void MatrixBandPart<GPUDevice, T>::Compute( \
-      const GPUDevice& d, Eigen::DenseIndex num_lower, \
-      Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
-      typename TTypes<T, 3>::Tensor output); \
-  extern template struct MatrixBandPart<GPUDevice, T>;
+#define DECLARE_GPU_SPEC(T) \
+  template <> \
+  struct MatrixBandPartFunctor<GPUDevice, T> { \
+    void operator()(OpKernelContext* context, const GPUDevice& device, \
+                    int num_lower_diags, int num_upper_diags, bool transpose, \
+                    typename TTypes<T, 3>::ConstTensor input, \
+                    typename TTypes<T, 3>::Tensor output); \
+  }; \
+  extern template struct MatrixBandPartFunctor<GPUDevice, T>;

 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_bool(DECLARE_GPU_SPEC);
 TF_CALL_complex64(DECLARE_GPU_SPEC);
 TF_CALL_complex128(DECLARE_GPU_SPEC);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor

 // Registration of the GPU implementations.
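For readers following the kernels, a reference sketch of the band-part
semantics the functors implement; band_start and band_end mirror the per-row
computation above (NumPy, 2-D case for brevity):

    import numpy as np

    def matrix_band_part(x, num_lower, num_upper):
      m, n = x.shape
      out = np.zeros_like(x)
      for row in range(m):
        band_start = 0 if num_lower < 0 else max(0, row - num_lower)
        band_end = n if num_upper < 0 else min(n, row + num_upper + 1)
        out[row, band_start:band_end] = x[row, band_start:band_end]
      return out

    a = np.arange(16.0).reshape(4, 4)
    np.testing.assert_array_equal(matrix_band_part(a, -1, 0), np.tril(a))
    np.testing.assert_array_equal(matrix_band_part(a, 0, -1), np.triu(a))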
@@ -16,61 +16,22 @@ limitations under the License.
 #ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
 #define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_

-// Generator definition for MatrixBandPartOp, must be compilable by nvcc.
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/platform/types.h"

 namespace tensorflow {

-namespace generator {
-
-template <typename T>
-class MatrixBandPartGenerator {
- public:
-  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE MatrixBandPartGenerator(
-      Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper,
-      typename TTypes<T, 3>::ConstTensor input)
-      : num_lower_(num_lower), num_upper_(num_upper), input_(input) {}
-
-  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
-  operator()(const Eigen::array<Eigen::DenseIndex, 3>& coords) const {
-    return (((num_lower_ < 0 || coords[1] - coords[2] <= num_lower_) &&
-             (num_upper_ < 0 || coords[2] - coords[1] <= num_upper_))
-                ? input_(coords)
-                : T());
-  }
-
- private:
-  const Eigen::DenseIndex num_lower_;
-  const Eigen::DenseIndex num_upper_;
-  typename TTypes<T, 3>::ConstTensor input_;
-};
-
-}  // namespace generator
-
 namespace functor {

-template <typename Device, typename T>
-struct MatrixBandPart {
-  EIGEN_ALWAYS_INLINE static void Compute(
-      const Device& d, Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper,
-      typename TTypes<T, 3>::ConstTensor input,
-      typename TTypes<T, 3>::Tensor output) {
-    if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
-        (num_upper < 0 || num_upper >= input.dimension(2))) {
-      output.device(d) = input;
-    } else {
-      generator::MatrixBandPartGenerator<T> generator(num_lower, num_upper,
-                                                      input);
-      output.device(d) = output.generate(generator);
-    }
-  }
+template <typename Device, typename Scalar>
+struct MatrixBandPartFunctor {
+  void operator()(OpKernelContext* context, const Device& device,
+                  int num_lower_diags, int num_upper_diags, bool transpose,
+                  typename TTypes<Scalar, 3>::ConstTensor input,
+                  typename TTypes<Scalar, 3>::Tensor output);
 };

 }  // namespace functor

 }  // namespace tensorflow

 #endif  // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
@@ -17,22 +17,92 @@ limitations under the License.

 #define EIGEN_USE_GPU

+#include <complex>
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/matrix_band_part_op.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"

 namespace tensorflow {

+namespace functor {
 typedef Eigen::GpuDevice GPUDevice;

-#define DEFINE_GPU_SPEC(T) \
-  template class generator::MatrixBandPartGenerator<T>; \
-  template struct functor::MatrixBandPart<GPUDevice, T>;
+template <bool transpose, typename Scalar>
+__global__ void MatrixBandPartKernel(const int num_threads,
+                                     const int batch_size, const int m,
+                                     const int n, const int num_lower_diags,
+                                     const int num_upper_diags,
+                                     const Scalar* input_ptr,
+                                     Scalar* output_ptr) {
+  if (!transpose) {
+    CUDA_1D_KERNEL_LOOP(index, num_threads) {
+      const int col = index % n;
+      const int row = (index / n) % m;
+      const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
+      const int band_end =
+          (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
+      if (col < band_start || col >= band_end) {
+        output_ptr[index] = Scalar();
+      } else {
+        output_ptr[index] = input_ptr[index];
+      }
+    }
+  } else {
+    const int matrix_size = m * n;
+    CUDA_1D_KERNEL_LOOP(index, num_threads) {
+      const int col = index % n;
+      const int row = (index / n) % m;
+      const int batch = index / matrix_size;
+      const int transpose_index = batch * matrix_size + n * col + row;
+      const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
+      const int band_end =
+          (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
+      if (col < band_start || col >= band_end) {
+        output_ptr[transpose_index] = Scalar();
+      } else {
+        output_ptr[transpose_index] = input_ptr[index];
+      }
+    }
+  }
+}
+
+template <typename Scalar>
+struct MatrixBandPartFunctor<GPUDevice, Scalar> {
+  void operator()(OpKernelContext* context, const GPUDevice& device,
+                  int num_lower_diags, int num_upper_diags, bool transpose,
+                  typename TTypes<Scalar, 3>::ConstTensor input,
+                  typename TTypes<Scalar, 3>::Tensor output) {
+    using CudaType = typename CUDAComplexT<Scalar>::type;
+    const int batch_size = input.dimension(0);
+    const int m = input.dimension(1);
+    const int n = input.dimension(2);
+    const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data());
+    CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
+    CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
+    if (transpose) {
+      MatrixBandPartKernel<true>
+          <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+              config.virtual_thread_count, batch_size, m, n, num_lower_diags,
+              num_upper_diags, input_ptr, output_ptr);
+    } else {
+      MatrixBandPartKernel<false>
+          <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+              config.virtual_thread_count, batch_size, m, n, num_lower_diags,
+              num_upper_diags, input_ptr, output_ptr);
+    }
+  }
+};
+
+#define DEFINE_GPU_SPEC(T) template struct MatrixBandPartFunctor<GPUDevice, T>;

 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
+TF_CALL_bool(DEFINE_GPU_SPEC);
 TF_CALL_complex64(DEFINE_GPU_SPEC);
 TF_CALL_complex128(DEFINE_GPU_SPEC);

-}  // end namespace tensorflow
+#undef DEFINE_GPU_SPEC
+}  // namespace functor
+}  // namespace tensorflow

 #endif  // GOOGLE_CUDA
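A sanity sketch of the transposed write index used by the kernel above:
transpose_index = batch * m * n + n * col + row performs a per-batch matrix
transpose, assuming square inner matrices (as at the current band-part call
sites). A quick NumPy check:

    import numpy as np

    batch_size, m, n = 2, 3, 3
    inp = np.arange(batch_size * m * n).reshape(batch_size, m, n)
    flat_out = np.empty(inp.size, dtype=inp.dtype)
    for index in range(inp.size):
      col = index % n
      row = (index // n) % m
      batch = index // (m * n)
      flat_out[batch * m * n + n * col + row] = inp.reshape(-1)[index]
    np.testing.assert_array_equal(flat_out.reshape(batch_size, n, m),
                                  inp.transpose(0, 2, 1))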
@@ -6898,6 +6898,41 @@ op {
     }
   }
 }
+op {
+  name: "DecodeRaw"
+  input_arg {
+    name: "bytes"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT16
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "little_endian"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
 op {
   name: "DecodeWav"
   input_arg {
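The new pbtxt entry above records DT_UINT16 as an allowed out_type for
DecodeRaw. A minimal sketch of what that enables through the op's Python
endpoint, tf.decode_raw (TF 1.x API):

    import numpy as np
    import tensorflow as tf

    raw = tf.constant(np.arange(4, dtype=np.uint16).tobytes())
    values = tf.decode_raw(raw, out_type=tf.uint16, little_endian=True)
    with tf.Session() as sess:
      print(sess.run(values))  # => [0 1 2 3] on a little-endian host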
@@ -399,7 +399,7 @@ matrix is assumed to be zero and not accessed.
 `rhs` is a tensor of shape `[..., M, K]`.

 The output is a tensor of shape `[..., M, K]`. If `adjoint` is
-`True` then the innermost matrices in output` satisfy matrix equations
+`True` then the innermost matrices in `output` satisfy matrix equations
 `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
 If `adjoint` is `False` then the strictly then the innermost matrices in
 `output` satisfy matrix equations