Branch 152703253 (#9112)
* Improve py_func error handling. Automatically translate some Python errors into corresponding TF errors at runtime. Change: 152156821
* Update interaction with libpng so that we use the public API instead of knowledge of the internal libpng data structures. Change: 152167754
* TensorBoard plugins now contain their own name/route prefix. Change: 152167807
* Pass the trainable flag to separable_conv2d biases. Change: 152170239
* Saving resource variables with a caching device. Change: 152171539
* Drop loss from estimator_spec.eval_metric_ops, as required by core Estimator. Change: 152179924
* sample_stats.percentile DOCFIX. Change: 152182295
* Added a memory optimizer to grappler. Change: 152184170
* Change the default behavior of the tf runs selector: if there are fewer than 41 runs, enable them all by default; if there are 41 runs or more, disable them all by default. This is in response to user complaints that having it enable only the first ten runs by default was confusing, because it was not obvious to users that some runs had been disabled. It still solves the initial complaint that having very many runs simultaneously enabled would lag the UI. Also changed the "toggle all runs" button to try to turn everything off before turning everything on, and improved the logic for detecting when the runs selection is back in the default state, so that we can avoid generating long URI strings wherever possible. Change: 152188948
* Autogenerated change: change TensorBoard TAG to 52. Change: 152189000
* Remove a warning that was only happening with config cuda. Change: 152189205
* Make resource variable shared name consistent with non-resource variables. Remove the colocation constraint between a resource variable's cached value and the variable itself. Change: 152192203
* Add a way to specify the optimization order; refactor and add constant folding to the meta optimizer. Change: 152193646
* Backport fixes and improvements from external Keras.
Change: 152198296
* Merge changes from github. Change: 152200430
* Go: Update generated wrapper functions for TensorFlow ops. Change: 152200754
* Update ops-related pbtxt files. Change: 152203174
* Make ImportGraphDef() work with functions. In addition to modifying graph_constructor.cc, this patch adds other functionality to enable importing functions: the ability to add FunctionDefLibraries to Graphs and FunctionLibraryDefinitions (in addition to existing functions), and a FunctionDefsEqual() utility function. Change: 152205258
* Expand contrib test to more than just test targets. Change: 152206822
* Preserve graph version during optimization. Change: 152213262
* Exclude enter and exit nodes from the shape refiner's constant folding. Change: 152213637
* Allow reshape_mover and algebraic_simplifier to make multiple mutations, by avoiding the short-circuit std::any_of. Change: 152232810
* Fix dynamic_rnn transpose bug (can input/output non-3d tensors). Also a few cleanups to RNN code. Change: 152267628
* Fix flaky tests. Change: 152272801
* Add an auto parallelization grappler optimization pass. Change: 152276787
* Change json.decode.JSONDecodeError to ValueError. JSONDecodeError seems to be the exception used in the simplejson module, not the json module. Change: 152278012
* Internal change. Change: 152281471
* [XLA] Force buffer sharing of separate while instructions. Change: 152288540
* replica_device_setter should work for resource variables. Change: 152289915
* Fix ./configure script: 1. add %workspace% in the .bazelrc file when using an import statement; 2. write action_env into the bazelrc file for environment variables required for OpenCL support. Change: 152290700
* Point a number of TensorBoard graph visualization-related help links to the new locations of the corresponding API documentation.
Change: 152293459
* Restore most of pull request #8606. Pull request #8606 added str(Label(...)) for most dependencies in tensorflow.bzl, allowing most functions to be used from repositories which include TensorFlow as a submodule. Unfortunately, it broke when pulled into Google and was removed in cl/152200430. This CL restores the change, except for two Android-only functions; these were the only problematic bits. Change: 152297413
* Removed dead code in Estimator. Change: 152297597
* Assert rank is at least equal to new_rank for `_sparse_inner_flatten`. Change: 152303319
* Extend quantization ranges to include 0.0f. Change: 152304380
* Remove Keras config file saving. Change: 152306552
* API backwards compatibility tests. Change: 152310869
* [TF:XLA] Add a test for an R3 -> R4 broadcast. Change: 152313967
* Fix the problem of not having enough placeholders for persistent tensor batch deletes. The deleter_key is always a device_name, so there is only one of it, and hence we cannot delete more than one handle at a time. The fix creates delete placeholders on demand; the maximum number of placeholders is _DEAD_HANDLES_THRESHOLD. Change: 152322770
* [XLA] Add several reduction tests. Change: 152323510
* Added the memory optimizer to the meta optimizer. Change: 152323689
* Started a set of utilities to categorize op types. Change: 152329057
* Add AudioSpectrogram op to TensorFlow for audio feature generation. Change: 152332221
* Update ops-related pbtxt files. Change: 152332812
* Automated rollback of change 152332221. Change: 152333917
* Call Py_CLEAR on dead fields during TF_RESOURCE-to-ndarray conversion. Change: 152338333
* [TF contrib seq2seq] Initial, incomplete implementation of beam search decoder. **DOES NOT WORK, pushed for collaboration only.** Change: 152343927
* [XLA] Change HloPassPipeline to disallow Add* calls after Run. Change: 152345578
* Automated rollback of change 152332812. Change: 152349057
* Remove all 64/32 bit compiler warnings from core/ops.
Change: 152353506
* libtensorflow.so: don't export private symbols. With this change, libtensorflow.so only exports functions defined in c_api.h. This also decreases the binary size of libtensorflow.so: on Linux from roughly 150MB to 67MB, and on OS X from roughly 101MB to 82MB. Also fixes #8923. Change: 152366053
* Add Elu ops in XLA. Change: 152383201
* Fixed test. ('broadcast_dims' has size 1.) Change: 152383633
* Add a more detailed error message for the rank assertion in _sparse_inner_flatten. Change: 152397909
* tensor_bundle: propagate errors related to directory creation. Change: 152401909
* matrix_adjoint added to contrib/linalg/linear_operator_util. Change: 152404828
* Add an is_active method to plugins. This method determines whether a plugin is active; a plugin may be inactive if, say, it lacks data. The new is_active method allows us to add a route to TensorBoard noting which plugins are active, so the frontend can avoid querying routes of inactive plugins. Change: 152406232
* Replace a gather op for shapes with a stack op so dilated convolutions can be placed on GPU even with strict placement (previously the gather went to CPU). Change: 152411159
* [TF:XLA] Implement BatchToSpace, BatchToSpaceND, SpaceToBatch, SpaceToBatchND. Fix crashes in core implementations of the same operators for zero-sized blocks. Change: 152416903
* Estimator saves relative paths in checkpoint. Change: 152420211
* Fix layers_test exception regex matching. Change: 152422855
* Unhide bijectors. Correct the TransformedDistribution docstring. Change: 152424418
* Choose a saner default for min_eval_frequency in the Experiment constructor for the GCS file system, because the default of 1 causes performance problems. Change: 152439984
* Inherit use_resource from scope for partitioned variables.
Change: 152442103
* Support quantized reshape in the hexagon runtime. Change: 152445539
* tfdbg CLI: add command list_source (ls), plus UI fixes and improvements. The new list_source (shorthand: ls) command lists the Python source files responsible for constructing the nodes and tensors encountered in the run() call. It divides the source files into two categories and lists them separately: 1) files that are not part of the TensorFlow Python library, and 2) files that are part of it. The list shows how many nodes, tensors, and tensor dumps each file is responsible for, and the file paths contain clickable links to the existing print_source/ps command. The list_source/ls command supports filtering by file-path and node-name regex patterns. UI fixes: fixed the inconsistent black vs. transparent background color that made the layout look messy on some terminal types (the transparent color is now used consistently for the default font color), and added clickable links in the print_source command output to expand source lines and graph elements. Change: 152446002
* tfcompile: be a little more verbose about missing required flags. Fixes #9014. Change: 152446338
* Disable failing test cases in pooling_ops_test. Change: 152447322
* Register more types for tf.image_crop_and_resize(). Resolves #9020. Change: 152448160
* Automated rollback of change 152439984. Change: 152450929
* Add a route to TensorBoard for fetching plugin names. Specifically, add a /data/plugins_listing route to the TensorBoard application; it responds with an object mapping the name of each initialized plugin to whether it is active, which could help the frontend avoid issuing requests to inactive plugins. Also ordered the listing of routes within application.py for a little more organization, and refactored the application test to use a fake plugin.
Change: 152451390
* Added the ability to retrieve the amount of usable GPU memory. Change: 152453470
* Allow setting the session ConfigProto in RunConfig and use it in Estimator. Change: 152454548
* Colocate ResourceVariable reads with their handles. Change: 152455939
* tfdbg: update doc for the new command list_source/ls. Change: 152456128
* Make rnn directions slightly easier to follow. Change: 152456296
* Internal change. Change: 152458104
* Adds batch renormalization. NOTE: if you use renormalization, you might want to use faster moving-average updates, i.e. lower `decay` values. Change: 152458872
* When using ImportGraphDef with a passed-in ShapeRefiner, use the producer version of the GraphDef when importing; the ShapeRefiner may be initialized with a different graph_def_version, so we need to be able to override it. The test failed without the change to graph_constructor and passes with it; it uses a legacy graph that is supported (reduction shape). Change: 152459169
* Allow any iterable for the `export_strategies` arg. Change: 152461826
* Log steps/sec every 100 steps in MonitoredSession, as before. Change: 152465320
* Fix documentation to note that in case of ties, the identity of the return value of ArgMin and ArgMax is not guaranteed. Change: 152465346
* Automated rollback of change 152465346. Change: 152465844
* Fix the shape inference fn on _ParallelConcatStart. Change: 152466076
* Fix the getting started guide: explain numerical differences in loss, and fix one example to print. Change: 152466119
* Remove superfluous mode argument. Change: 152467334
* Add a tool that converts HLO computations to a TensorFlow GraphDef which can be visualized on TensorBoard. This CL defines a basic tensorflow::OpDef for each HLO instruction/node; more attributes (e.g. shapes, colors) will be added in the future. Change: 152477918
* [TF:XLA] Increase the shard count of //third_party/tensorflow/compiler/tests:spacetobatch_test to reduce flakiness when built under ASAN.
Change: 152496244
* Make the projector plugin backend read assets saved via the PluginAssets API, while keeping backwards compatibility with the old way of looking up assets. Change: 152504793
* Move MNIST pointers to the mirror hosted by CVDF on Google Cloud. Fixes #9031. Change: 152504901
* Merge changes from github. Change: 152508170
* Update the API after the earlier change to the default step counter frequency. Change: 152517535
* Move a few random op helper functions to header files: 1. shape_inference::RandomShape; 2. OpKernel::MakeShape(Tensor, TensorShape*). Change: 152522156
* Addresses the divide-by-zero bug. Change: 152522488
* Clarify the doc on tf.assign. Change: 152523909
* Sparse Adam for resource variables. Change: 152525327
* Automated rollback of change 152310869. Change: 152528732
* Add an env_var tf_sync_on_finish_bool that, if true, blocks until the device has finished all queued operations in a step. Change: 152533676
* Add more node attributes for HloInstruction on TensorBoard, e.g. shape and layout. Change: 152534472
* Add tf.complex64 GPU support to tf.gather. Also add ldg specializations for std::complex. Change: 152537848
* Formatting changes. Change: 152544842
* Upgrade TensorBoard TypeScript to 2.2.1. See also: #8326. Change: 152545950
* TEST: get reasonable test sizes on the linalg library, removing the need for sharding. Change: 152546409
* Disable _testSourceUtilModuleReturnsTrue as it's causing open-source issues. Change: 152548721
* Fix a race due to unsafe buffer forwarding in the maxpooling second-order gradients added in #6664. Re-enable previously flaky tests and clean up a few minor things in maxpooling_op_gpu.cu.cc. Change: 152550050
* LinearOperator: adjoint_arg kwarg added to all operators. Now operator.apply(x, adjoint_arg=True) means the adjoint of 'x' is taken before applying the operator; sometimes this is done more efficiently than simply taking the adjoint. Change: 152560471
* Adds weighted_average_loss metric key.
Change: 152560999
* Documentation: fix a bug in the manual device placement example. Change: 152563392
* Change for internal compatibility. Use std::vector for storage instead of a map; do the sorting in place and return the same vector to avoid any copies. On larger streams it is about 50% faster. Change: 152576112
* Add tf.add_n GPU support for complex64/complex128. Also adds a unit test for tf.add_n. Change: 152577190
* Adds support for nested types in tf.case and tf.cond; adds a "strict" mode which disables silent unpacking of singleton lists; adds shape inference to tf.case; adds a lot of unit tests. Change: 152581097
* [XLA] Add support for folding transpose into convolution. Change: 152581336
* Add a smoke test to ensure that the doc generator runs. Change: 152592164
* Add tensorboard to the _do_not_descend_map of the PublicAPIVisitor. Change: 152592268
* Add auto parallelization to the meta optimizer. Enable MetaOptimizer if any one of the optimizers is on. Change: 152598517
* Update ops-related pbtxt files. Change: 152629248
* Prevent the renorm_weight from being updated too early. Change: 152631776
* Automated rollback of change 152528732. Change: 152652473
* Construct TensorBoard dashboards in a JS list. Previously, adding a dashboard to TensorBoard involved changing logic in several places. As part of this effort, added constructors to dashboards and tweaked logic in various dashboards to preserve original behavior; for instance, the graph dashboard can only perform fitting after it is attached to the DOM. Change: 152658532
* Make CheckpointSaverListener visible next to CheckpointSaverHook. Change: 152662945
* tfdbg CLI: minor bug fixes. 1: The calculation of the scroll command in the scroll bar didn't take into account that the y-coordinate of the scroll block is in the ScrollBar coordinate system, while the mouse-click y-coordinate is in the screen coordinate system. 2: The y position of the ScrollBar was off by one.
3: The command box was not re-created after mouse-triggered commands, leading to a strange-looking cursor position. Change: 152684294
* Remove obsolete use of validate_indices from embedding_ops.py; validate_indices is ignored, so it shouldn't appear in new code. Change: 152691948
* Preparation for using GMock matchers in XLA tests. Change: 152691970
* Replace RuntimeException with RuntimeError in coordinator documentation. Change: 152697758
* Move the TensorBoard debugger plugin to be internal; this feature is currently not open-source anyway. Change: 152700267
* Add a single-machine tf.learn Estimator implementation for the WALS solver. Change: 152700915
* Add tf.contrib.training.python_input, making it easy to feed data into TensorFlow from Python coroutines. Change: 152701623
* Show that QuantizeToFloat consistently introduces a small error, equal to range_min - round(range_min / range_scale) * range_scale. Change: 152702015
* Internal changes. Change: 152703253
* Remove tensorflow/tensorboard/plugins/debugger, as part of merge resolution.
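The QuantizeToFloat entry above gives a closed form for the round-trip error. A minimal sketch in plain Python (the range values here are hypothetical, chosen only for illustration) showing that snapping range_min to the quantized grid reproduces the stated formula:

```python
def quantize_to_float_error(range_min, range_scale):
    """Round-trip error when range_min is snapped to the quantized grid.

    Mirrors the formula from the changelog entry:
    error = range_min - round(range_min / range_scale) * range_scale
    """
    quantized = round(range_min / range_scale)  # nearest grid point
    dequantized = quantized * range_scale       # back to float
    return range_min - dequantized

# Hypothetical range: the error is bounded by half a quantization step.
err = quantize_to_float_error(-1.0, 0.3)
assert abs(err) <= 0.3 / 2
```

The error is nonzero whenever range_min is not an exact multiple of range_scale, which is why the entry calls it "consistent".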
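One entry above swaps json.decode.JSONDecodeError for ValueError. In the standard-library json module (Python 3.5+), JSONDecodeError is itself a subclass of ValueError, so catching ValueError covers both modules; a quick check:

```python
import json

caught = None
try:
    json.loads("not valid json")
except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
    caught = exc

# The concrete type is JSONDecodeError, but ValueError already catches it.
assert isinstance(caught, json.JSONDecodeError)
assert isinstance(caught, ValueError)
```

Catching the broader ValueError also keeps the code working under simplejson, whose decode errors likewise derive from ValueError.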
This commit is contained in: parent 8682f1f878, commit 9e7bf40381
@@ -308,10 +308,12 @@ filegroup(
        "//tensorflow/tensorboard/components/vz_sorting/test:all_files",
        "//tensorflow/tensorboard/lib:all_files",
        "//tensorflow/tensorboard/plugins:all_files",
        "//tensorflow/tensorboard/plugins/debugger:all_files",
        "//tensorflow/tensorboard/plugins/projector:all_files",
        "//tensorflow/tensorboard/plugins/text:all_files",
        "//tensorflow/tensorboard/scripts:all_files",
        "//tensorflow/tools/api/golden:all_files",
        "//tensorflow/tools/api/lib:all_files",
        "//tensorflow/tools/api/tests:all_files",
        "//tensorflow/tools/common:all_files",
        "//tensorflow/tools/compatibility:all_files",
        "//tensorflow/tools/dist_test/server:all_files",
@@ -346,6 +348,11 @@ filegroup(
    ),
)

filegroup(
    name = "docs_src",
    data = glob(["docs_src/**/*.md"]),
)

# -------------------------------------------
# New rules should be added above this target.
# -------------------------------------------
@@ -44,6 +44,17 @@ xla_proto_library(
    ],
)

cc_library(
    name = "test",
    testonly = 1,
    hdrs = ["test.h"],
    visibility = [":friends"],
    deps = [
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
    ],
)

cc_library(
    name = "types",
    hdrs = ["types.h"],
@@ -256,10 +267,9 @@ cc_test(
        ":array4d",
        ":literal_util",
        ":shape_util",
        ":test_helpers",
        ":test",
        ":types",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
@@ -21,14 +21,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

using ::testing::ElementsAre;
using ::testing::ElementsAreArray;

class LiteralUtilTest : public ::testing::Test {
 protected:
  LiteralUtilTest() {
@@ -159,9 +161,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
  // clang-format on

  auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
  EXPECT_MATCH(testing::PBToVec<tensorflow::protobuf_int64>(
                   literal->shape().dimensions()),
               testing::VectorMatcher<tensorflow::protobuf_int64>({2, 3, 2}));
  EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
  string result = LiteralUtil::ToString(*literal);
  const string expected = R"(f32[2,3,2] {
{ { 1, 2 },
@@ -182,9 +182,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
    {2001, 2002},
  }, /*projection_p=*/1, /*projection_z=*/2);
  // clang-format on
  EXPECT_MATCH(
      testing::PBToVec(literal->shape().dimensions()),
      testing::VectorMatcher<tensorflow::protobuf_int64>({1, 2, 3, 2}));
  EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2));
  string result = LiteralUtil::ToString(*literal);
  const string expected = R"(f32[1,2,3,2] {
{ // i0=0
@@ -204,10 +202,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
}

TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
  EXPECT_MATCH(
      testing::PBToVec<tensorflow::protobuf_int64>(
          literal_r4_2x2x3x3_dim0major_->shape().dimensions()),
      testing::VectorMatcher<tensorflow::protobuf_int64>({2, 2, 3, 3}));
  EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(),
              ElementsAre(2, 2, 3, 3));
  string result = LiteralUtil::ToString(*literal_r4_2x2x3x3_dim0major_);
  const string expected = R"(f32[2,2,3,3] {
{ // i0=0
@@ -516,27 +512,23 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) {
  auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
  EXPECT_EQ(mat_dim0minor->s32s_size(), 6);
  EXPECT_MATCH(testing::PBToVec<int32>(mat_dim0minor->s32s()),
               testing::VectorMatcher<int32>({1, 4, 2, 5, 3, 6}));
  EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6));

  // Test expected memory layout when using Relayout to row major.
  auto relaid_mat_to_dim0major =
      LiteralUtil::Relayout(*mat_dim0minor, layout_r2_dim0major_);
  EXPECT_MATCH(testing::PBToVec<int32>(relaid_mat_to_dim0major->s32s()),
               testing::VectorMatcher<int32>({1, 2, 3, 4, 5, 6}));
  EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6));

  // Test expected memory layout of R2 created with dim0-major (row-major).
  auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
  EXPECT_EQ(mat_dim0major->s32s_size(), 6);
  EXPECT_MATCH(testing::PBToVec<int32>(mat_dim0major->s32s()),
               testing::VectorMatcher<int32>({1, 2, 3, 4, 5, 6}));
  EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6));

  // Test expected memory layout when using Relayout to column major.
  auto relaid_mat_to_dim0minor =
      LiteralUtil::Relayout(*mat_dim0major, layout_r2_dim0minor_);
  EXPECT_MATCH(testing::PBToVec<int32>(relaid_mat_to_dim0minor->s32s()),
               testing::VectorMatcher<int32>({1, 4, 2, 5, 3, 6}));
  EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6));
}

TEST_F(LiteralUtilTest, TestR3LinearLayout) {
@@ -558,28 +550,28 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {

  EXPECT_EQ(lit_dim0minor->s32s_size(), 12);
  std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
  EXPECT_MATCH(testing::PBToVec<int32>(lit_dim0minor->s32s()),
               testing::VectorMatcher<int32>(expected_dim0minor));
  EXPECT_THAT(lit_dim0minor->s32s(),
              testing::ElementsAreArray(expected_dim0minor));

  // Test expected memory layout when using Relayout to row major.
  auto relaid_lit_to_dim0major =
      LiteralUtil::Relayout(*lit_dim0minor, layout_r3_dim0major_);
  std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  EXPECT_MATCH(testing::PBToVec<int32>(relaid_lit_to_dim0major->s32s()),
               testing::VectorMatcher<int32>(expected_dim0major));
  EXPECT_THAT(relaid_lit_to_dim0major->s32s(),
              testing::ElementsAreArray(expected_dim0major));

  // Test expected memory layout of R3 created with dim0-major (row-major).
  auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      arr3d, layout_r3_dim0major_);
  EXPECT_EQ(lit_dim0major->s32s_size(), 12);
  EXPECT_MATCH(testing::PBToVec<int32>(lit_dim0major->s32s()),
               testing::VectorMatcher<int32>(expected_dim0major));
  EXPECT_THAT(lit_dim0major->s32s(),
              testing::ElementsAreArray(expected_dim0major));

  // Test expected memory layout when using Relayout to column major.
  auto relaid_lit_to_dim0minor =
      LiteralUtil::Relayout(*lit_dim0major, layout_r3_dim0minor_);
  EXPECT_MATCH(testing::PBToVec<int32>(relaid_lit_to_dim0minor->s32s()),
               testing::VectorMatcher<int32>(expected_dim0minor));
  EXPECT_THAT(relaid_lit_to_dim0minor->s32s(),
              testing::ElementsAreArray(expected_dim0minor));
}

TEST_F(LiteralUtilTest, SliceR0S32) {
@@ -1431,7 +1431,9 @@ cc_library(
    deps = [
        ":hlo",
        ":hlo_pass",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
        "//tensorflow/core:lib",
    ],
@@ -1442,11 +1444,13 @@ cc_test(
    srcs = ["transpose_folding_test.cc"],
    deps = [
        ":hlo",
        ":shape_inference",
        ":transpose_folding",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:computation_builder",
        "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
        "//tensorflow/core:lib",
        "//tensorflow/core:test_main",
@@ -232,7 +232,14 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module,
    pass.AddPass<ReshapeMover>();
    pass.AddPass<HloConstantFolding>();
  }
  pipeline.AddPass<TransposeFolding>(PotentiallyImplementedAsEigenDot);
  pipeline.AddPass<TransposeFolding>(
      [](const HloInstruction& dot,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return PotentiallyImplementedAsEigenDot(dot)
                   ? candidate_operands
                   : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloSubcomputationUnification>();
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
  pipeline.AddPass<CpuInstructionFusion>();
@@ -133,7 +133,13 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
    pass.AddPass<HloConstantFolding>();
  }
  pipeline.AddPass<ConvolutionFolding>();
  pipeline.AddPass<TransposeFolding>(ImplementedAsGemm);
  pipeline.AddPass<TransposeFolding>(
      [](const HloInstruction& dot,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return ImplementedAsGemm(dot) ? candidate_operands
                                      : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloSubcomputationUnification>();
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
  pipeline.AddPass<HloDCE>();
@@ -25,16 +25,7 @@ limitations under the License.
namespace xla {
namespace gpu {

const int64 kWarpSize = 32;

// Precondition: "hlo" is an operand of a Dot instruction.
//
// Returns whether "hlo" is foldable to its user.
bool IsOperandFoldableToDot(const HloInstruction& hlo);

// Returns true if GpuCompiler can fold any operands of "dot" into "dot" for
// better performance.
bool CanFoldOperandsIntoDot(const HloInstruction& dot);
constexpr int64 kWarpSize = 32;

// Returns true if `hlo` will be implemented as a call to BLAS gemm.
bool ImplementedAsGemm(const HloInstruction& hlo);
@@ -21,7 +21,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@@ -30,43 +32,56 @@ namespace xla {

namespace {

bool IsOperandFoldableToDot(const HloInstruction& hlo) {
  return hlo.IsRank2Transpose() &&
         hlo.user_count() == 1;  // The dot is its only user.
}

bool CanFoldOperandsIntoDot(
TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
    const HloInstruction& dot,
    const TransposeFolding::IsTransposableGemmFn& is_transposable_gemm) {
    const TransposeFolding::TransposableGemmOperandsFn&
        transposable_gemm_operands) {
  if (HloOpcode::kDot != dot.opcode()) {
    return false;
    return {};
  }

  if (!is_transposable_gemm(dot)) {
    return false;
  TransposeFolding::OperandIndices operand_set;
  for (int64 i = 0; i < dot.operand_count(); ++i) {
    auto& operand = *dot.operand(i);
    if (operand.IsRank2Transpose() && operand.user_count() == 1) {
      operand_set.push_back(i);
    }
  }

  const HloInstruction* lhs = dot.operand(0);
  const HloInstruction* rhs = dot.operand(1);
  bool lhs_foldable = IsOperandFoldableToDot(*lhs);
  bool rhs_foldable = IsOperandFoldableToDot(*rhs);
  if (!lhs_foldable && !rhs_foldable) {
    return false;
  }
  return true;
  return transposable_gemm_operands(dot, operand_set);
}

TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
    const HloInstruction& convolution,
    const TransposeFolding::TransposableConvOperandsFn&
        transposable_conv_operands) {
  if (HloOpcode::kConvolution != convolution.opcode()) {
    return {};
  }

  // We only support folding the RHS.
  const int64 kRhsOperandIndex = 1;
  auto& operand = *convolution.operand(kRhsOperandIndex);
  if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) {
    return transposable_conv_operands(convolution, {kRhsOperandIndex});
  }

  return {};
}

using InstructionOperandsPair =
    std::pair<HloInstruction*, TransposeFolding::OperandIndices>;

// Folds the operands of `dot` that are foldable transposes. `computation` is
// the parent HLO computation of `dot`. `module` is the parent HloModule of
// `computation`.
// the parent HLO computation of `dot`.
//
// Returns whether the module is changed.
bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) {
bool FoldTransposeIntoDot(InstructionOperandsPair pair,
                          HloComputation* computation) {
  auto* dot = pair.first;
  std::vector<HloInstruction*> instructions_to_fuse(1, dot);
  for (HloInstruction* operand : dot->operands()) {
    if (IsOperandFoldableToDot(*operand)) {
      instructions_to_fuse.push_back(operand);
    }
  for (const int64 operand_index : pair.second) {
    instructions_to_fuse.push_back(dot->mutable_operand(operand_index));
  }

  // Early-exit if no operands are foldable.
@@ -79,28 +94,95 @@ bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) {
  return true;
}

// Folds the operands of `convolution` that are foldable transposes.
// `computation` is the parent HLO computation of `convolution`.
//
// Returns whether the module is changed.
bool FoldTransposeIntoConvolution(InstructionOperandsPair pair,
                                  HloComputation* computation) {
  auto& convolution = *pair.first;

  // We only support fusing the RHS transpose into convolution.
  //
  // ConvolutionDimensionNumbers doesn't make enough of a distinction between
  // the output and the activations.
  //
  // TODO(b/37125184): Support transposing the LHS too.
  if (pair.second.size() != 1 || pair.second.front() != 1) {
    return false;
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
  HloInstruction& transpose = *convolution.mutable_operand(1);
  CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose);
  const auto& transpose_dimensions = transpose.dimensions();
  HloInstruction& transpose_operand = *transpose.mutable_operand(0);

  // Everything remains the same except for the kernel dimension numbers. We
  // need to apply the transpose permutation to the original shape to figure out
  // what the new logical dimensions are.
  ConvolutionDimensionNumbers new_dnums = dnums;
  new_dnums.set_kernel_input_feature_dimension(
      transpose_dimensions[dnums.kernel_input_feature_dimension()]);
  new_dnums.set_kernel_output_feature_dimension(
      transpose_dimensions[dnums.kernel_output_feature_dimension()]);
  for (auto& kernel_spatial_dimension :
       *new_dnums.mutable_kernel_spatial_dimensions()) {
    kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
  }

  auto new_conv = HloInstruction::CreateConvolve(
      convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
      convolution.window(), new_dnums);
  TF_CHECK_OK(computation->ReplaceWithNewInstruction(&convolution,
                                                     std::move(new_conv)));

  return true;
}

}  // namespace

TransposeFolding::TransposeFolding(IsTransposableGemmFn is_transposable_gemm)
    : is_transposable_gemm_(std::move(is_transposable_gemm)) {}
TransposeFolding::TransposeFolding(
    TransposableGemmOperandsFn transposable_gemm_operands,
    TransposableConvOperandsFn transposable_conv_operands)
    : transposable_gemm_operands_(std::move(transposable_gemm_operands)),
      transposable_conv_operands_(std::move(transposable_conv_operands)) {}

StatusOr<bool> TransposeFolding::Run(HloModule* module) {
// Modifying the graph while traversing is dangerous, so we find all folding
|
||||
// opportunities before actually folding them.
|
||||
HloComputation* entry_computation = module->entry_computation();
|
||||
|
||||
std::vector<HloInstruction*> foldable_dots;
|
||||
auto visit_fn = [this, &foldable_dots](HloInstruction* instruction) {
|
||||
if (CanFoldOperandsIntoDot(*instruction, is_transposable_gemm_)) {
|
||||
foldable_dots.emplace_back(instruction);
|
||||
std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_dots;
|
||||
std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_convolutions;
|
||||
auto visit_fn = [this, &foldable_dots,
|
||||
&foldable_convolutions](HloInstruction* instruction) {
|
||||
{
|
||||
OperandIndices operand_indices =
|
||||
CanFoldOperandsIntoDot(*instruction, transposable_gemm_operands_);
|
||||
if (!operand_indices.empty()) {
|
||||
foldable_dots.emplace_back(instruction, operand_indices);
|
||||
}
|
||||
}
|
||||
{
|
||||
OperandIndices operand_indices = CanFoldOperandsIntoConvolution(
|
||||
*instruction, transposable_conv_operands_);
|
||||
if (!operand_indices.empty()) {
|
||||
foldable_convolutions.emplace_back(
|
||||
std::make_pair(instruction, operand_indices));
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
};
|
||||
TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn));
|
||||
|
||||
bool changed = false;
|
||||
for (HloInstruction* dot : foldable_dots) {
|
||||
changed |= FoldTransposeIntoDot(dot, entry_computation);
|
||||
for (InstructionOperandsPair& pair : foldable_dots) {
|
||||
changed |= FoldTransposeIntoDot(pair, entry_computation);
|
||||
}
|
||||
for (InstructionOperandsPair& pair : foldable_convolutions) {
|
||||
changed |= FoldTransposeIntoConvolution(pair, entry_computation);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
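The kernel dimension-number rewrite in `FoldTransposeIntoConvolution` above can be sketched outside of XLA: every dimension number `d` of the original convolution is replaced by `transpose_dimensions[d]`, so the folded convolution can read the transpose's operand directly. This is a minimal model, not the XLA API; the plain dict of named dimension numbers stands in for `ConvolutionDimensionNumbers`:

```python
def fold_transpose_into_dnums(dnums, transpose_dimensions):
    # Map each logical kernel dimension through the transpose permutation:
    # logical dimension d of the transpose's output is dimension
    # transpose_dimensions[d] of its input.
    return {name: transpose_dimensions[d] for name, d in dnums.items()}

dnums = {
    "kernel_input_feature": 0,
    "kernel_output_feature": 1,
    "kernel_spatial_0": 2,
    "kernel_spatial_1": 3,
}
# The {1, 0, 2, 3} permutation used by the FoldConvDimSwapTransposeRhs test
# below: input and output feature dimensions swap, spatial dims stay put.
folded = fold_transpose_into_dnums(dnums, [1, 0, 2, 3])
```

This mirrors the assertions in the conv-folding tests, which check that the input and output feature dimension numbers swapped.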
@@ -25,16 +25,37 @@ namespace xla {
 // operator is implemented by a GEMM kernel that can transpose its inputs.
 class TransposeFolding : public HloPassInterface {
  public:
-  // IsTransposableGemmFn should return true iff the instruction argument is
-  // implemented as a GEMM kernel that supports transposing its arguments.
-  typedef std::function<bool(const HloInstruction&)> IsTransposableGemmFn;
-  explicit TransposeFolding(IsTransposableGemmFn is_transposable_gemm);
+  using OperandIndices = std::vector<int64>;
+
+  // Returns the set of foldable operands for a given HLO and some candidate
+  // operands.
+  using FoldableOperands = std::function<OperandIndices(const HloInstruction&,
+                                                        const OperandIndices&)>;
+  using TransposableGemmOperandsFn = FoldableOperands;
+  using TransposableConvOperandsFn = FoldableOperands;
+
+  // Helper function to explicitly not fold transposes.
+  static OperandIndices NeverFoldTranspose(const HloInstruction&,
+                                           const OperandIndices&) {
+    return {};
+  }
+
+  // transposable_gemm_operands returns the set of operands it wants to fold if
+  // the instruction argument is implemented as a GEMM kernel that supports
+  // transposing its arguments.
+  //
+  // transposable_conv_operands returns the set of operands it wants to fold if
+  // the instruction argument is implemented as a convolution that supports
+  // transposing its arguments.
+  explicit TransposeFolding(
+      TransposableGemmOperandsFn transposable_gemm_operands,
+      TransposableConvOperandsFn transposable_conv_operands);
   tensorflow::StringPiece name() const override { return "transpose-folding"; }

   StatusOr<bool> Run(HloModule* module) override;

  private:
-  IsTransposableGemmFn is_transposable_gemm_;
+  TransposableGemmOperandsFn transposable_gemm_operands_;
+  TransposableConvOperandsFn transposable_conv_operands_;
 };

 }  // namespace xla
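The `FoldableOperands` callback contract declared in the header can be sketched as follows. This is a hypothetical Python model, not the C++ API: the callback receives an instruction plus the operand indices that are transposes, and returns the subset it wants folded (`is_transposing_gemm` is a stand-in predicate, not a real TF/XLA function):

```python
def transposable_gemm_operands(instruction, candidate_operands,
                               is_transposing_gemm):
    # Fold every candidate operand when the op lowers to a GEMM that can
    # transpose its inputs; otherwise fold nothing.
    return list(candidate_operands) if is_transposing_gemm(instruction) else []

def never_fold_transpose(instruction, candidate_operands):
    # Mirrors TransposeFolding::NeverFoldTranspose: always decline.
    return []
```

The test fixture later in the diff wires up exactly this shape of callback as a C++ lambda around `gpu::ImplementedAsGemm`.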
@@ -16,15 +16,17 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/transpose_folding.h"

 #include <memory>
-#include <set>
+#include <unordered_set>
 #include <vector>

+#include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,12 +37,22 @@ namespace xla {
 class TransposeFoldingTest : public ::testing::Test {
  protected:
   void FoldTranspose(HloModule* module) {
-    TransposeFolding transpose_folding(gpu::ImplementedAsGemm);
+    TransposeFolding transpose_folding(
+        [](const HloInstruction& dot,
+           const TransposeFolding::OperandIndices& candidate_operands) {
+          return gpu::ImplementedAsGemm(dot)
+                     ? candidate_operands
+                     : TransposeFolding::OperandIndices{};
+        },
+        [](const HloInstruction& convolution,
+           const TransposeFolding::OperandIndices& candidate_operands) {
+          return candidate_operands;
+        });
     EXPECT_IS_OK(transpose_folding.Run(module).status());
   }
 };

-TEST_F(TransposeFoldingTest, FoldTranspose) {
+TEST_F(TransposeFoldingTest, FoldDotTranspose) {
   auto builder = HloComputation::Builder("entry_computation");
   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
@@ -61,7 +73,7 @@ TEST_F(TransposeFoldingTest, FoldTranspose) {
   FoldTranspose(&module);

   // Instructions after folding: x, y, and the fusion.
-  std::set<HloInstruction*> instruction_set;
+  std::unordered_set<HloInstruction*> instruction_set;
   for (auto& instruction : entry_computation->instructions()) {
     instruction_set.insert(instruction.get());
   }
@@ -77,7 +89,7 @@ TEST_F(TransposeFoldingTest, FoldTranspose) {
   EXPECT_EQ(4, fusion->fused_instructions().size());
 }

-TEST_F(TransposeFoldingTest, FoldTransposeConstant) {
+TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
   auto builder = HloComputation::Builder("entry_computation");
   // 2x1
   HloInstruction* const0 = builder.AddInstruction(
@@ -115,7 +127,7 @@ TEST_F(TransposeFoldingTest, FoldTransposeConstant) {
       entry_computation->root_instruction()->fused_instructions().size());
 }

-TEST_F(TransposeFoldingTest, FuseWithConstantOperands) {
+TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
   auto builder = HloComputation::Builder("entry");
   // (1.0 + 2.0) * (2.0 - 3.0)
   HloInstruction* const1 = builder.AddInstruction(
@@ -146,4 +158,168 @@ TEST_F(TransposeFoldingTest, FuseWithConstantOperands) {
   EXPECT_EQ(6, callee_computation->instructions().size());
 }

+// Test that a two dimension swap of the kernel gets folded into convolution.
+TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
+  auto builder = HloComputation::Builder("entry_computation");
+  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
+      /*name=*/"x"));
+  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
+      /*name=*/"y"));
+  HloInstruction* transpose_y =
+      builder.AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3}));
+  auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+  Window window;
+  for (int i = 0; i < 2; ++i) {
+    WindowDimension* dim = window.add_dimensions();
+    dim->set_padding_low(0);
+    dim->set_padding_high(0);
+    dim->set_base_dilation(1);
+    dim->set_window_dilation(1);
+    dim->set_stride(1);
+    dim->set_size(
+        transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
+  }
+  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
+      x->shape(), transpose_y->shape(), window, dnums);
+  EXPECT_IS_OK(conv_shape);
+  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+
+  HloModule module("test_module");
+  HloComputation* entry_computation =
+      module.AddEntryComputation(builder.Build(conv));
+  FoldTranspose(&module);
+
+  // Instructions after folding: x, y, and the convolution.
+  std::unordered_set<HloInstruction*> instruction_set;
+  for (auto& instruction : entry_computation->instructions()) {
+    instruction_set.insert(instruction.get());
+  }
+  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+  CHECK_EQ(1, instruction_set.size())
+      << "entry_computation should contain exactly 3 instructions.";
+  HloInstruction* new_conv = *instruction_set.begin();
+  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
+  EXPECT_EQ(dnums.kernel_input_feature_dimension(),
+            new_conv->convolution_dimension_numbers()
+                .kernel_output_feature_dimension());
+  EXPECT_EQ(dnums.kernel_output_feature_dimension(),
+            new_conv->convolution_dimension_numbers()
+                .kernel_input_feature_dimension());
+}
+
+// Test that a complex transpose of the kernel gets folded into convolution.
+TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
+  auto builder = HloComputation::Builder("entry_computation");
+  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
+      /*name=*/"x"));
+  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {1, 2, 1, 3}),
+      /*name=*/"y"));
+  HloInstruction* transpose_y =
+      builder.AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2}));
+  auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+  Window window;
+  for (int i = 0; i < 2; ++i) {
+    WindowDimension* dim = window.add_dimensions();
+    dim->set_padding_low(0);
+    dim->set_padding_high(0);
+    dim->set_base_dilation(1);
+    dim->set_window_dilation(1);
+    dim->set_stride(1);
+    dim->set_size(
+        transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
+  }
+  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
+      x->shape(), transpose_y->shape(), window, dnums);
+  EXPECT_IS_OK(conv_shape);
+  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+
+  HloModule module("test_module");
+  HloComputation* entry_computation =
+      module.AddEntryComputation(builder.Build(conv));
+  FoldTranspose(&module);
+
+  // Instructions after folding: x, y, and the convolution.
+  std::unordered_set<HloInstruction*> instruction_set;
+  for (auto& instruction : entry_computation->instructions()) {
+    instruction_set.insert(instruction.get());
+  }
+  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+  CHECK_EQ(1, instruction_set.size())
+      << "entry_computation should contain exactly 3 instructions.";
+  HloInstruction* new_conv = *instruction_set.begin();
+  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
+  EXPECT_EQ(dnums.kernel_input_feature_dimension(),
+            new_conv->convolution_dimension_numbers()
+                .kernel_output_feature_dimension());
+  EXPECT_EQ(dnums.kernel_spatial_dimensions(1),
+            new_conv->convolution_dimension_numbers()
+                .kernel_input_feature_dimension());
+  EXPECT_EQ(
+      dnums.kernel_output_feature_dimension(),
+      new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(0));
+  EXPECT_EQ(
+      dnums.kernel_spatial_dimensions(0),
+      new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
+}
+
+// Test that a transpose of the activations does not get folded into
+// convolution.
+TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
+  auto builder = HloComputation::Builder("entry_computation");
+  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
+      /*name=*/"x"));
+  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
+      /*name=*/"y"));
+  HloInstruction* transpose_x =
+      builder.AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3}));
+  auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+  Window window;
+  for (int i = 0; i < 2; ++i) {
+    WindowDimension* dim = window.add_dimensions();
+    dim->set_padding_low(0);
+    dim->set_padding_high(0);
+    dim->set_base_dilation(1);
+    dim->set_window_dilation(1);
+    dim->set_stride(1);
+    dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
+  }
+  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
+      transpose_x->shape(), y->shape(), window, dnums);
+  EXPECT_IS_OK(conv_shape);
+  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+
+  HloModule module("test_module");
+  HloComputation* entry_computation =
+      module.AddEntryComputation(builder.Build(conv));
+  FoldTranspose(&module);
+
+  // Instructions after folding: transpose_x, y, and the convolution.
+  std::unordered_set<HloInstruction*> instruction_set;
+  for (auto& instruction : entry_computation->instructions()) {
+    instruction_set.insert(instruction.get());
+  }
+  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+  CHECK_EQ(1, instruction_set.erase(transpose_x))
+      << "transpose_x is not in entry_computation.";
+  CHECK_EQ(1, instruction_set.erase(conv))
+      << "conv is not in entry_computation.";
+  CHECK_EQ(0, instruction_set.size())
+      << "entry_computation should contain exactly 4 instructions.";
+}
+
 }  // namespace xla
tensorflow/compiler/xla/test.h (new file, 48 lines)
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TEST_H_
+#define TENSORFLOW_COMPILER_XLA_TEST_H_
+
+// This header includes gmock.h and enables the use of gmock matchers in tests
+// in third_party/tensorflow/compiler/xla.
+//
+// Tests including this header can use the macros EXPECT_THAT(...) and
+// ASSERT_THAT(...) in combination with gmock matchers.
+// Example:
+//  std::vector<int> vec = Foo();
+//  EXPECT_THAT(vec, ::testing::ElementsAre(1, 2, 3));
+//
+// For more details on gmock matchers see:
+// https://github.com/google/googletest/blob/master/googlemock/docs/CheatSheet.md#matchers
+//
+// The advantages of using gmock matchers instead of self defined matchers are
+// better error messages, more maintainable tests and more test coverage.
+//
+// Note that while the use of gmock matchers is allowed in the xla project, the
+// use of mocks is disallowed in the whole tensorflow project!
+
+#include "tensorflow/core/platform/platform.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
+#include "testing/base/public/gmock.h"
+#else
+#include <gmock/gmock-generated-matchers.h>
+#include <gmock/gmock-matchers.h>
+#endif
+
+#include "tensorflow/core/platform/test.h"
+
+#endif  // TENSORFLOW_COMPILER_XLA_TEST_H_
@@ -55,7 +55,7 @@ class WeightedQuantilesBuffer {
       : max_size_(std::min(block_size << 1, max_elements)) {
     QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size
                           << ", " << max_elements << ")";
-    map_.reserve(max_size_);
+    vec_.reserve(max_size_);
   }

   // Disallow copying as it's semantically non-sensical in the Squawd algorithm
@@ -77,42 +77,48 @@ class WeightedQuantilesBuffer {
       return;
     }

-    // Insert entry to map if not already present else
-    // accumulate the new weight.
-    auto result = map_.insert(BufferMapEntry(value, weight));
-    if (!result.second) {
-      result.first->second += weight;
-    }
+    // Push back the entry to the buffer.
+    vec_.push_back(BufferEntry(value, weight));
   }

-  // Returns a sorted vector view of the base buffer. Callers should
-  // minimize how often this is called, ideally only right after the buffer
-  // becomes full.
-  std::vector<BufferEntry> GenerateEntryList() const {
+  // Returns a sorted vector view of the base buffer and clears the buffer.
+  // Callers should minimize how often this is called, ideally only right after
+  // the buffer becomes full.
+  std::vector<BufferEntry> GenerateEntryList() {
     std::vector<BufferEntry> ret;
-    ret.reserve(map_.size());
-    std::transform(map_.begin(), map_.end(), std::back_inserter(ret),
-                   [](const BufferMapEntry& map_entry) {
-                     return BufferEntry(map_entry.first, map_entry.second);
-                   });
+    if (vec_.size() == 0) {
+      return ret;
+    }
+    ret.swap(vec_);
+    vec_.reserve(max_size_);
+    std::sort(ret.begin(), ret.end());
+    size_t num_entries = 0;
+    for (size_t i = 1; i < ret.size(); ++i) {
+      if (ret[i].value != ret[i - 1].value) {
+        BufferEntry tmp = ret[i];
+        ++num_entries;
+        ret[num_entries] = tmp;
+      } else {
+        ret[num_entries].weight += ret[i].weight;
+      }
+    }
+    ret.resize(num_entries + 1);
     return ret;
   }

-  int64 Size() const { return map_.size(); }
-  bool IsFull() const { return map_.size() >= max_size_; }
-  void Clear() { map_.clear(); }
+  int64 Size() const { return vec_.size(); }
+  bool IsFull() const { return vec_.size() >= max_size_; }
+  void Clear() { vec_.clear(); }

  private:
-  using BufferMap = typename std::unordered_map<ValueType, WeightType>;
-  using BufferMapEntry = typename BufferMap::value_type;
+  using BufferVector = typename std::vector<BufferEntry>;

   // Comparison function.
   static constexpr decltype(CompareFn()) kCompFn = CompareFn();

   // Base buffer.
   size_t max_size_;
-  BufferMap map_;
+  BufferVector vec_;
 };

 template <typename ValueType, typename WeightType, typename CompareFn>
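The new `GenerateEntryList()` compaction above (sort the raw entries, then merge runs of equal values by summing their weights) can be sketched in Python. This is a minimal model of the algorithm, not the TensorFlow class; entries are `(value, weight)` tuples:

```python
def generate_entry_list(vec):
    # Sort by value; equal values become adjacent, then a single pass
    # merges each run by accumulating weights, mirroring the in-place
    # compaction loop in GenerateEntryList().
    ret = sorted(vec)
    if not ret:
        return ret
    out = [list(ret[0])]
    for value, weight in ret[1:]:
        if value != out[-1][0]:
            out.append([value, weight])      # new unique value
        else:
            out[-1][1] += weight             # duplicate: accumulate weight
    return [tuple(e) for e in out]

# The same entries the PushEntryFull test pushes: duplicates of value 2
# merge into a single (2, 4) entry.
entries = [(5, 9), (2, 3), (-1, 7), (2, 1)]
```

This matches the `expected` list built in the buffer tests below: `(-1, 7), (2, 4), (5, 9)`.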
@@ -69,47 +69,32 @@ TEST_F(WeightedQuantilesBufferTest, PushEntryFull) {
   expected.emplace_back(2, 4);
   expected.emplace_back(5, 9);

-  // At this point, we have a compaction and duplicate entry 2 is merged.
-  EXPECT_FALSE(buffer.IsFull());
-  EXPECT_EQ(buffer.GenerateEntryList(), expected);
-
-  // Push another unique entry.
-  buffer.PushEntry(3, 2);
+  // At this point, we have pushed 4 entries and we expect the buffer to be
+  // full.
   EXPECT_TRUE(buffer.IsFull());
+  EXPECT_EQ(buffer.GenerateEntryList(), expected);
+  EXPECT_FALSE(buffer.IsFull());
 }

+TEST_F(WeightedQuantilesBufferTest, PushEntryFullDeath) {
+  // buffer capacity is 4.
+  Buffer buffer(2, 100);
+  buffer.PushEntry(5, 9);
+  buffer.PushEntry(2, 3);
+  buffer.PushEntry(-1, 7);
+  buffer.PushEntry(2, 1);
+
+  std::vector<BufferEntry> expected;
+  expected.emplace_back(-1, 7);
+  expected.emplace_back(2, 4);
+  expected.emplace_back(5, 9);
+
+  // At this point, we have pushed 4 entries and we expect the buffer to be
+  // full.
+  EXPECT_TRUE(buffer.IsFull());
+  // Can't push any more entries before clearing.
+  EXPECT_DEATH(({ buffer.PushEntry(6, 6); }), "Buffer already full");
+}
+
-TEST_F(WeightedQuantilesBufferTest, RandomizedPush) {
-  // buffer capacity is 6.
-  Buffer buffer(3, 100);
-  std::array<double, 5> elements = {{1.1, 2.3, 5.1, 8.0, 12.6}};
-  std::array<double, elements.size()> counts;
-  counts.fill(0.0);
-
-  random::PhiloxRandom philox(13);
-  random::SimplePhilox rand(&philox);
-
-  for (int iters = 10000; iters-- > 0; --iters) {
-    // Add entry.
-    int32 picked_idx = rand.Uniform(elements.size());
-    buffer.PushEntry(elements[picked_idx], 1.0);
-    ++counts[picked_idx];
-
-    // We can't fill buffer with a number of unique elements < capacity.
-    EXPECT_FALSE(buffer.IsFull());
-  }
-
-  // Ensure we didn't lose any information.
-  std::vector<BufferEntry> expected;
-  for (int i = 0; i < elements.size(); ++i) {
-    if (counts[i] > 0) {
-      expected.emplace_back(elements[i], counts[i]);
-    }
-  }
-  EXPECT_EQ(buffer.GenerateEntryList(), expected);
-}
-
 }  // namespace
 }  // namespace tensorflow
@@ -91,12 +91,11 @@ class WeightedQuantilesStream {
     // and push weighted quantile summary up the level chain.
     if (buffer_.IsFull()) {
       PushBuffer(buffer_);
-      buffer_.Clear();
     }
   }

   // Pushes full buffer while maintaining approximation error invariants.
-  void PushBuffer(const Buffer& buffer) {
+  void PushBuffer(Buffer& buffer) {
     // Validate state.
     QCHECK(!finalized_) << "Finalize() already called.";

@@ -124,7 +123,6 @@ class WeightedQuantilesStream {

     // Flush any remaining buffer elements.
     PushBuffer(buffer_);
-    buffer_.Clear();

     // Create final merged summary.
     local_summary_.Clear();
@@ -91,9 +91,10 @@ TEST_F(WeightedQuantilesSummaryTest, BuildFromBuffer) {
 }

 TEST_F(WeightedQuantilesSummaryTest, CompressSeparately) {
+  const auto entry_list = buffer1_->GenerateEntryList();
   for (int new_size = 9; new_size >= 2; --new_size) {
     Summary summary;
-    summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+    summary.BuildFromBufferEntries(entry_list);
     summary.Compress(new_size);

     // Expect a max approximation error of 1 / n
@@ -161,10 +162,12 @@ TEST_F(WeightedQuantilesSummaryTest, CompressRandomized) {

 TEST_F(WeightedQuantilesSummaryTest, MergeSymmetry) {
   // Create two separate summaries and merge.
+  const auto list_1 = buffer1_->GenerateEntryList();
+  const auto list_2 = buffer2_->GenerateEntryList();
   Summary summary1;
-  summary1.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+  summary1.BuildFromBufferEntries(list_1);
   Summary summary2;
-  summary2.BuildFromBufferEntries(buffer2_->GenerateEntryList());
+  summary2.BuildFromBufferEntries(list_2);

   // Merge summary 2 into 1 and verify.
   summary1.Merge(summary2);
@@ -178,7 +181,7 @@ TEST_F(WeightedQuantilesSummaryTest, MergeSymmetry) {
   EXPECT_EQ(summary1.Size(), 14);  // 14 unique values.

   // Merge summary 1 into 2 and verify same result.
-  summary1.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+  summary1.BuildFromBufferEntries(list_1);
   summary2.Merge(summary1);
   EXPECT_EQ(summary2.ApproximationError(), 0.0);
   EXPECT_EQ(summary2.MinValue(),
@@ -212,7 +212,6 @@ add_python_module("tensorflow/tensorboard")
 add_python_module("tensorflow/tensorboard/backend")
 add_python_module("tensorflow/tensorboard/backend/event_processing")
 add_python_module("tensorflow/tensorboard/plugins")
-add_python_module("tensorflow/tensorboard/plugins/debugger")
 add_python_module("tensorflow/tensorboard/plugins/projector")
 add_python_module("tensorflow/tensorboard/plugins/text")
 add_python_module("tensorflow/tensorboard/scripts")
@@ -202,6 +202,7 @@ tf_py_test(
     additional_deps = [
         ":factorization_py",
         ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
+        ":factorization_ops_test_utils_py",
         "//third_party/py/numpy",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
@@ -190,7 +190,8 @@ class WALSModel(object):
                num_col_shards=1,
                row_weights=1,
                col_weights=1,
-               use_factors_weights_cache=True):
+               use_factors_weights_cache=True,
+               use_gramian_cache=True):
     """Creates model for WALS matrix factorization.

     Args:
@@ -224,6 +225,8 @@ class WALSModel(object):
       col_weights: See row_weights.
       use_factors_weights_cache: When True, the factors and weights will be
         cached on the workers before the updates start. Defaults to True.
+      use_gramian_cache: When True, the Gramians will be cached on the workers
+        before the updates start. Defaults to True.
     """
     self._input_rows = input_rows
     self._input_cols = input_cols
@@ -243,6 +246,7 @@ class WALSModel(object):
                                                   self._num_col_shards,
                                                   "col_weights")
     self._use_factors_weights_cache = use_factors_weights_cache
+    self._use_gramian_cache = use_gramian_cache
     self._row_factors = self._create_factors(self._input_rows,
                                              self._n_components,
                                              self._num_row_shards, row_init,
@@ -495,10 +499,13 @@ class WALSModel(object):
     """Creates local cache of factors, weights and gramian for rows and columns.

     Note that currently the caching strategy is as follows:
-      When initiating a row(column) update, the column(row) gramian is computed
-      and cached while the row gramian is reset; optionally, column(row) factors
-      and weights are cached and row(column) factors and weights are reset when
-      use_factors_weights_cache is True.
+      When initiating a row (resp. column) update:
+        - The column (resp. row) gramian is computed.
+        - Optionally, if use_gramian_cache is True, the column (resp. row)
+          Gramian is cached, while the row (resp. column) gramian is reset.
+        - Optionally, if use_factors_weights_cache is True, the column
+          (resp. row) factors and weights are cached, while the row
+          (resp. column) factors and weights are reset.
     """

     (self._row_factors_cache, row_factors_cache_init,
@@ -515,18 +522,20 @@ class WALSModel(object):
         self._row_weights,
         "row_wt_cache",
         pass_through=not self._use_factors_weights_cache)

     (self._col_wt_cache, col_wt_cache_init, _) = self._cached_copy(
         self._col_weights,
         "col_wt_cache",
         pass_through=not self._use_factors_weights_cache)

     (self._row_gramian_cache, row_gramian_cache_init,
      row_gramian_cache_reset) = self._cached_copy(
-         self._row_gramian, "row_gramian_cache", pass_through=False)
+         self._row_gramian,
+         "row_gramian_cache",
+         pass_through=not self._use_gramian_cache)
     (self._col_gramian_cache, col_gramian_cache_init,
      col_gramian_cache_reset) = self._cached_copy(
-         self._col_gramian, "col_gramian_cache", pass_through=False)
+         self._col_gramian,
+         "col_gramian_cache",
+         pass_through=not self._use_gramian_cache)

     self._row_updates_init = control_flow_ops.group(col_factors_cache_init,
                                                     row_factors_cache_reset,
@@ -18,7 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.factorization.python.ops import factorization_ops
+from tensorflow.contrib.framework.python.ops import variables as framework_variables
 from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -221,3 +224,321 @@ class _SweepHook(session_run_hook.SessionRunHook):
    self._is_sweep_done = run_values.results[0]
    logging.info("Partial fit done.")


def _wals_factorization_model_function(features, labels, mode, params):
  """Model function for the WALSMatrixFactorization estimator.

  Args:
    features: Dictionary of features. See WALSMatrixFactorization.
    labels: Must be None.
    mode: A model_fn.ModeKeys object.
    params: Dictionary of parameters containing arguments passed to the
      WALSMatrixFactorization constructor.

  Returns:
    A ModelFnOps object.
  """
  assert labels is None
  use_factors_weights_cache = (
      params["use_factors_weights_cache_for_training"]
      and mode == model_fn.ModeKeys.TRAIN)
  use_gramian_cache = (
      params["use_gramian_cache_for_training"]
      and mode == model_fn.ModeKeys.TRAIN)
  model = factorization_ops.WALSModel(
      params["num_rows"],
      params["num_cols"],
      params["embedding_dimension"],
      unobserved_weight=params["unobserved_weight"],
      regularization=params["regularization_coeff"],
      row_init=params["row_init"],
      col_init=params["col_init"],
      num_row_shards=params["num_row_shards"],
      num_col_shards=params["num_col_shards"],
      row_weights=params["row_weights"],
      col_weights=params["col_weights"],
      use_factors_weights_cache=use_factors_weights_cache,
      use_gramian_cache=use_gramian_cache)

  # Get input rows and cols. We either update rows or columns, depending on
  # the value of row_sweep, which is maintained using a session hook.
  input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
  input_cols = features[WALSMatrixFactorization.INPUT_COLS]
  input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0])
  input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0])

  # Train ops, controlled using the SweepHook.
  # We need to run the following ops:
  #   Before a row sweep:
  #     row_update_prep_gramian_op
  #     initialize_row_update_op
  #   During a row sweep:
  #     update_row_factors_op
  #   Before a col sweep:
  #     col_update_prep_gramian_op
  #     initialize_col_update_op
  #   During a col sweep:
  #     update_col_factors_op

  is_row_sweep_var = variables.Variable(
      True, "is_row_sweep",
      collections=[ops.GraphKeys.GLOBAL_VARIABLES])
  # The row sweep is determined by is_row_sweep_var (controlled by the
  # sweep_hook) in TRAIN mode, and manually in EVAL mode.
  is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW]
                  if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var)

  def update_row_factors():
    return model.update_row_factors(sp_input=input_rows, transpose_input=False)

  def update_col_factors():
    return model.update_col_factors(sp_input=input_cols, transpose_input=True)

  _, train_op, loss = control_flow_ops.cond(
      is_row_sweep, update_row_factors, update_col_factors)

  row_prep_ops = [model.row_update_prep_gramian_op,
                  model.initialize_row_update_op]
  col_prep_ops = [model.col_update_prep_gramian_op,
                  model.initialize_col_update_op]
  cache_init_ops = [model.worker_init]

  sweep_hook = _SweepHook(
      is_row_sweep_var,
      train_op,
      params["num_rows"],
      params["num_cols"],
      input_row_indices,
      input_col_indices,
      row_prep_ops,
      col_prep_ops,
      cache_init_ops,
  )

  # Prediction ops (only return predictions in INFER mode).
  predictions = {}
  if mode == model_fn.ModeKeys.INFER:
    project_row = features[WALSMatrixFactorization.PROJECT_ROW]
    projection_weights = features.get(
        WALSMatrixFactorization.PROJECTION_WEIGHTS)

    def get_row_projection():
      return model.project_row_factors(
          sp_input=input_rows,
          projection_weights=projection_weights,
          transpose_input=False)

    def get_col_projection():
      return model.project_col_factors(
          sp_input=input_cols,
          projection_weights=projection_weights,
          transpose_input=True)

    predictions[WALSMatrixFactorization.PROJECTION_RESULT] = (
        control_flow_ops.cond(
            project_row, get_row_projection, get_col_projection))

  return model_fn.ModelFnOps(
      mode=mode,
      predictions=predictions,
      loss=loss,
      eval_metric_ops={},
      train_op=train_op,
      training_hooks=[sweep_hook])


class WALSMatrixFactorization(estimator.Estimator):
  """An Estimator for Weighted Matrix Factorization, using the WALS method.

  WALS (Weighted Alternating Least Squares) is an algorithm for weighted matrix
  factorization. It computes a low-rank approximation of a given sparse (n x m)
  matrix A, as the product of two matrices, U * V^T, where U is an (n x k)
  matrix and V is an (m x k) matrix. Here k is the rank of the approximation,
  also called the embedding dimension. We refer to U as the row factors, and V
  as the column factors.
  See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem
  formulation.

  The training proceeds in sweeps: during a row sweep, we fix V and solve for
  U. During a column sweep, we fix U and solve for V. Each one of these
  problems is an unconstrained quadratic minimization problem and can be solved
  exactly (it can also be solved in mini-batches, since the solution decouples
  nicely). Alternating between sweeps is achieved by using a hook during
  training, which is responsible for keeping track of the sweeps and running
  preparation ops at the beginning of each sweep. It also updates the
  global_step variable, which keeps track of the number of batches processed
  since the beginning of training.
  The current implementation assumes that the training is run on a single
  machine, and will fail if config.num_worker_replicas is not equal to one.
  Training is done by calling self.fit(input_fn=input_fn), where input_fn
  provides two queues: one for rows of the input matrix, and one for rows of
  the transposed input matrix (i.e. columns of the original matrix). Note that
  during a row sweep, only row batches are processed (ignoring column batches)
  and vice-versa.

  For prediction, given a new set of input rows A' (e.g. new rows of the A
  matrix), we compute a corresponding set of row factors U', such that
  U' * V^T is a good approximation of A'. We call this operation a row
  projection. A similar operation is defined for columns.
  Projection is done by calling self.get_projections(input_fn=input_fn), where
  input_fn satisfies the constraints given below.

  The input functions must satisfy the following constraints: calling input_fn
  must return a tuple (features, labels) where labels is None, and features is
  a dict containing the following keys:
  TRAIN:
    - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
      Rows of the input matrix to process (or to project).
    - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
      Columns of the input matrix to process (or to project), transposed.
  INFER:
    - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
      Rows to project.
    - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
      Columns to project.
    - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
      the rows or columns.
    - WALSMatrixFactorization.PROJECTION_WEIGHTS (Optional): float32 Tensor
      (vector). The weights to use in the projection.
  EVAL:
    - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
      Rows to project.
    - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
      Columns to project.
    - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
      the rows or columns.
  """
  # Keys to be used in model_fn.
  # Feature keys:
  INPUT_ROWS = "input_rows"
  INPUT_COLS = "input_cols"
  PROJECT_ROW = "project_row"
  PROJECTION_WEIGHTS = "projection_weights"
  # Predictions key:
  PROJECTION_RESULT = "projection"

  def __init__(self,
               num_rows,
               num_cols,
               embedding_dimension,
               unobserved_weight=0.1,
               regularization_coeff=None,
               row_init="random",
               col_init="random",
               num_row_shards=1,
               num_col_shards=1,
               row_weights=1,
               col_weights=1,
               use_factors_weights_cache_for_training=True,
               use_gramian_cache_for_training=True,
               model_dir=None,
               config=None):
    """Creates a model for matrix factorization using the WALS method.

    Args:
      num_rows: Total number of rows for the input matrix.
      num_cols: Total number of cols for the input matrix.
      embedding_dimension: Dimension to use for the factors.
      unobserved_weight: Weight of the unobserved entries of the matrix.
      regularization_coeff: Weight of the L2 regularization term. Defaults to
        None, in which case the problem is not regularized.
      row_init: Initializer for the row factor. Must be either:
        - A tensor: the row factor matrix is initialized to this tensor,
        - A numpy constant,
        - "random": the rows are initialized using a normal distribution.
      col_init: Initializer for the column factor. See row_init.
      num_row_shards: Number of shards to use for the row factors.
      num_col_shards: Number of shards to use for the column factors.
      row_weights: Must be in one of the following three formats:
        - None: In this case, the weight of every entry is the
          unobserved_weight and the problem simplifies to ALS. Note that, in
          this case, col_weights must also be set to None.
        - List of lists of non-negative scalars, of the form
          [[w_0, w_1, ...], [w_k, ...], [...]],
          where the number of inner lists is equal to the number of row factor
          shards and the elements in each inner list are the weights for the
          rows of that shard. In this case,
          w_ij = unobserved_weight + row_weights[i] * col_weights[j].
        - A non-negative scalar: this value is used for all row weights.
        Note that it is allowed to have row_weights as a list and col_weights
        as a scalar, or vice-versa.
      col_weights: See row_weights.
      use_factors_weights_cache_for_training: Boolean, whether the factors and
        weights will be cached on the workers before the updates start, during
        training. Defaults to True. Note that caching is disabled during
        prediction.
      use_gramian_cache_for_training: Boolean, whether the Gramians will be
        cached on the workers before the updates start, during training.
        Defaults to True. Note that caching is disabled during prediction.
      model_dir: The directory to save the model results and log files.
      config: A Configuration object. See Estimator.

    Raises:
      ValueError: If config.num_worker_replicas is strictly greater than one.
        The current implementation only supports running on a single worker.
    """
    # TODO(walidk): Support distributed training.
    # TODO(walidk): Support power-law based weight computation.
    # TODO(walidk): Add factor lookup by indices, with caching.
    # TODO(walidk): Support caching during prediction.

    params = {
        "num_rows": num_rows,
        "num_cols": num_cols,
        "embedding_dimension": embedding_dimension,
        "unobserved_weight": unobserved_weight,
        "regularization_coeff": regularization_coeff,
        "row_init": row_init,
        "col_init": col_init,
        "num_row_shards": num_row_shards,
        "num_col_shards": num_col_shards,
        "row_weights": row_weights,
        "col_weights": col_weights,
        "use_factors_weights_cache_for_training":
            use_factors_weights_cache_for_training,
        "use_gramian_cache_for_training": use_gramian_cache_for_training
    }
    self._row_factors_names = ["row_factors_shard_%d" % i
                               for i in range(num_row_shards)]
    self._col_factors_names = ["col_factors_shard_%d" % i
                               for i in range(num_col_shards)]

    super(WALSMatrixFactorization, self).__init__(
        model_fn=_wals_factorization_model_function,
        params=params,
        model_dir=model_dir,
        config=config)

    if self._config is not None and self._config.num_worker_replicas > 1:
      raise ValueError("WALSMatrixFactorization must be run on a single worker "
                       "replica.")

  def get_row_factors(self):
    """Returns the row factors of the model, loading them from checkpoint.

    Should only be run after training.

    Returns:
      A list of the row factors of the model.
    """
    return [self.get_variable_value(name) for name in self._row_factors_names]

  def get_col_factors(self):
    """Returns the column factors of the model, loading them from checkpoint.

    Should only be run after training.

    Returns:
      A list of the column factors of the model.
    """
    return [self.get_variable_value(name) for name in self._col_factors_names]

  def get_projections(self, input_fn):
    """Computes the projections of the rows or columns given in input_fn.

    Runs predict() with the given input_fn, and returns the results. Should
    only be run after training.

    Args:
      input_fn: Input function which specifies the rows or columns to project.

    Returns:
      A generator of the projected factors.
    """
    return (result[WALSMatrixFactorization.PROJECTION_RESULT]
            for result in self.predict(input_fn=input_fn))

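The class docstring above notes that each sweep is an unconstrained quadratic minimization that can be solved exactly, and that a row projection uses the same solve as a row sweep. A simplified dense numpy sketch of a single sweep (not the TensorFlow implementation, which handles sparsity, sharding, and the `unobserved_weight` term described in wals.md; `update_rows` and the weight matrix here are illustrative assumptions):

```python
import numpy as np

def update_rows(a, v, weights, reg):
    """Exact row update for weighted matrix factorization (dense sketch).

    For fixed column factors v, each row u_i minimizes
      sum_j weights[i, j] * (a[i, j] - u_i . v_j)**2 + reg * ||u_i||**2,
    a quadratic whose minimizer solves a small k x k linear system.
    """
    n, k = a.shape[0], v.shape[1]
    u = np.zeros((n, k))
    for i in range(n):
        w = np.diag(weights[i])
        lhs = v.T @ w @ v + reg * np.eye(k)
        rhs = v.T @ w @ a[i]
        u[i] = np.linalg.solve(lhs, rhs)
    return u

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 7))
v = rng.normal(size=(7, 2))
weights = np.full((5, 7), 0.1)  # uniform weights: plain regularized ALS

def loss(u, v):
    return (weights * (a - u @ v.T) ** 2).sum() + 0.01 * (
        (u ** 2).sum() + (v ** 2).sum())

u = update_rows(a, v, weights, reg=0.01)        # row sweep: solve for U
l1 = loss(u, v)
v2 = update_rows(a.T, u, weights.T, reg=0.01)   # col sweep: same solve on A^T
assert loss(u, v2) <= l1 + 1e-9                 # sweeps never increase the loss
```

Because each sweep is an exact minimizer, alternating row and column sweeps monotonically decreases the loss, which is what makes the sweep-based training loop converge.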
@@ -18,16 +18,311 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import json
import numpy as np

from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils
from tensorflow.contrib.factorization.python.ops import wals as wals_lib
from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import session_run_hook


class WALSMatrixFactorizationTest(test.TestCase):
  INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX

  def np_array_to_sparse(self, np_array):
    """Transforms an np.array to a tf.SparseTensor."""
    return factorization_ops_test_utils.np_matrix_to_tf_sparse(np_array)

  def calculate_loss(self):
    """Calculates the loss of the current (trained) model."""
    current_rows = embedding_ops.embedding_lookup(
        self._model.get_row_factors(), math_ops.range(self._num_rows),
        partition_strategy='div')
    current_cols = embedding_ops.embedding_lookup(
        self._model.get_col_factors(), math_ops.range(self._num_cols),
        partition_strategy='div')
    row_wts = embedding_ops.embedding_lookup(
        self._row_weights, math_ops.range(self._num_rows),
        partition_strategy='div')
    col_wts = embedding_ops.embedding_lookup(
        self._col_weights, math_ops.range(self._num_cols),
        partition_strategy='div')
    sp_inputs = self.np_array_to_sparse(self.INPUT_MATRIX)
    return factorization_ops_test_utils.calculate_loss(
        sp_inputs, current_rows, current_cols, self._regularization_coeff,
        self._unobserved_weight, row_wts, col_wts)

  # TODO(walidk): Replace with input_reader_utils functions once open sourced.
  def remap_sparse_tensor_rows(self, sp_x, row_ids, shape):
    """Remaps the row ids of a tf.SparseTensor."""
    old_row_ids, old_col_ids = array_ops.split(
        value=sp_x.indices, num_or_size_splits=2, axis=1)
    new_row_ids = array_ops.gather(row_ids, old_row_ids)
    new_indices = array_ops.concat([new_row_ids, old_col_ids], 1)
    return sparse_tensor.SparseTensor(
        indices=new_indices, values=sp_x.values, dense_shape=shape)

  # TODO(walidk): Add an option to randomize inputs.
  def input_fn(self, np_matrix, batch_size, project_row=None,
               projection_weights=None, col_ids=None):
    """Returns an input_fn that selects row and col batches from np_matrix."""
    def extract_features(row_batch, col_batch, shape):
      row_ids = row_batch[0]
      col_ids = col_batch[0]
      rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape)
      cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape)
      features = {
          wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows,
          wals_lib.WALSMatrixFactorization.INPUT_COLS: cols,
      }
      return features

    def _fn():
      num_rows = np.shape(np_matrix)[0]
      num_cols = np.shape(np_matrix)[1]
      row_ids = math_ops.range(num_rows, dtype=dtypes.int64)
      col_ids = math_ops.range(num_cols, dtype=dtypes.int64)
      sp_mat = self.np_array_to_sparse(np_matrix)
      sp_mat_t = sparse_ops.sparse_transpose(sp_mat)
      row_batch = input_lib.batch(
          [row_ids, sp_mat],
          batch_size=min(batch_size, num_rows),
          capacity=10,
          enqueue_many=True)
      col_batch = input_lib.batch(
          [col_ids, sp_mat_t],
          batch_size=min(batch_size, num_cols),
          capacity=10,
          enqueue_many=True)

      features = extract_features(row_batch, col_batch, sp_mat.dense_shape)
      if projection_weights is not None:
        weights_batch = input_lib.batch(
            projection_weights,
            batch_size=batch_size,
            capacity=10,
            enqueue_many=True)
        features[wals_lib.WALSMatrixFactorization.PROJECTION_WEIGHTS] = (
            weights_batch)
      if project_row is not None:
        features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = (
            constant_op.constant(project_row))

      labels = None
      return features, labels

    return _fn

  @property
  def row_steps(self):
    return np.ceil(self._num_rows / self.batch_size)

  @property
  def col_steps(self):
    return np.ceil(self._num_cols / self.batch_size)

  @property
  def batch_size(self):
    return 2

  @property
  def use_cache(self):
    return False

  def setUp(self):
    self._num_rows = 5
    self._num_cols = 7
    self._embedding_dimension = 3
    self._unobserved_weight = 0.1
    self._num_row_shards = 2
    self._num_col_shards = 3
    self._regularization_coeff = 0.01
    self._col_init = [
        # Shard 0.
        [[-0.36444709, -0.39077035, -0.32528427],
         [1.19056475, 0.07231052, 2.11834812],
         [0.93468881, -0.71099287, 1.91826844]],
        # Shard 1.
        [[1.18160152, 1.52490723, -0.50015002],
         [1.82574749, -0.57515913, -1.32810032]],
        # Shard 2.
        [[-0.15515432, -0.84675711, 0.13097958],
         [-0.9246484, 0.69117504, 1.2036494]],
    ]
    self._row_weights = [[0.1, 0.2, 0.3], [0.4, 0.5]]
    self._col_weights = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]

    # Values of row and column factors after running one iteration of factor
    # updates.
    self._row_factors_0 = [[0.097689, -0.219293, -0.020780],
                           [0.50842, 0.64626, 0.22364],
                           [0.401159, -0.046558, -0.192854]]
    self._row_factors_1 = [[1.20597, -0.48025, 0.35582],
                           [1.5564, 1.2528, 1.0528]]
    self._col_factors_0 = [[2.4725, -1.2950, -1.9980],
                           [0.44625, 1.50771, 1.27118],
                           [1.39801, -2.10134, 0.73572]]
    self._col_factors_1 = [[3.36509, -0.66595, -3.51208],
                           [0.57191, 1.59407, 1.33020]]
    self._col_factors_2 = [[3.3459, -1.3341, -3.3008],
                           [0.57366, 1.83729, 1.26798]]
    self._model = wals_lib.WALSMatrixFactorization(
        self._num_rows,
        self._num_cols,
        self._embedding_dimension,
        self._unobserved_weight,
        col_init=self._col_init,
        regularization_coeff=self._regularization_coeff,
        num_row_shards=self._num_row_shards,
        num_col_shards=self._num_col_shards,
        row_weights=self._row_weights,
        col_weights=self._col_weights,
        use_factors_weights_cache_for_training=self.use_cache,
        use_gramian_cache_for_training=self.use_cache)

  def test_fit(self):
    # Row sweep.
    input_fn = self.input_fn(np_matrix=self.INPUT_MATRIX,
                             batch_size=self.batch_size)
    self._model.fit(input_fn=input_fn, steps=self.row_steps)
    row_factors = self._model.get_row_factors()
    self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
    self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)

    # Col sweep.
    # Running fit a second time will resume training from the checkpoint.
    input_fn = self.input_fn(np_matrix=self.INPUT_MATRIX,
                             batch_size=self.batch_size)
    self._model.fit(input_fn=input_fn, steps=self.col_steps)
    col_factors = self._model.get_col_factors()
    self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
    self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
    self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)

  def test_predict(self):
    input_fn = self.input_fn(np_matrix=self.INPUT_MATRIX,
                             batch_size=self.batch_size)
    # Project rows 1 and 4 from the input matrix.
    proj_input_fn = self.input_fn(
        np_matrix=self.INPUT_MATRIX[[1, 4], :],
        batch_size=2,
        project_row=True,
        projection_weights=[[0.2, 0.5]])

    self._model.fit(input_fn=input_fn, steps=self.row_steps)
    projections = self._model.get_projections(proj_input_fn)
    projected_rows = list(itertools.islice(projections, 2))

    self.assertAllClose(
        projected_rows,
        [self._row_factors_0[1], self._row_factors_1[1]],
        atol=1e-3)

    # Project columns 5, 3, 1 from the input matrix.
    proj_input_fn = self.input_fn(
        np_matrix=self.INPUT_MATRIX[:, [5, 3, 1]],
        batch_size=3,
        project_row=False,
        projection_weights=[[0.6, 0.4, 0.2]])

    self._model.fit(input_fn=input_fn, steps=self.col_steps)
    projections = self._model.get_projections(proj_input_fn)
    projected_cols = list(itertools.islice(projections, 3))
    self.assertAllClose(
        projected_cols,
        [self._col_factors_2[0], self._col_factors_1[0],
         self._col_factors_0[1]],
        atol=1e-3)

  def test_eval(self):
    # Do a row sweep, then evaluate the model on row inputs.
    # The evaluate function returns the loss of the projected rows, but since
    # projection is idempotent, the eval loss must match the model loss.
    input_fn = self.input_fn(np_matrix=self.INPUT_MATRIX,
                             batch_size=self.batch_size)
    self._model.fit(input_fn=input_fn, steps=self.row_steps)
    eval_input_fn_row = self.input_fn(np_matrix=self.INPUT_MATRIX, batch_size=1,
                                      project_row=True)
    loss = self._model.evaluate(
        input_fn=eval_input_fn_row, steps=self._num_rows)['loss']

    with self.test_session():
      true_loss = self.calculate_loss()

    self.assertNear(
        loss, true_loss, err=.001,
        msg="""After row update, eval loss = {}, does not match the true
        loss = {}.""".format(loss, true_loss))

    # Do a col sweep, then evaluate the model on col inputs.
    self._model.fit(input_fn=input_fn, steps=self.col_steps)
    eval_input_fn_col = self.input_fn(np_matrix=self.INPUT_MATRIX, batch_size=1,
                                      project_row=False)
    loss = self._model.evaluate(
        input_fn=eval_input_fn_col, steps=self._num_cols)['loss']

    with self.test_session():
      true_loss = self.calculate_loss()

    self.assertNear(
        loss, true_loss, err=.001,
        msg="""After col update, eval loss = {}, does not match the true
        loss = {}.""".format(loss, true_loss))


class WALSMatrixFactorizationTestCached(WALSMatrixFactorizationTest):

  @property
  def use_cache(self):
    return True


class WALSMatrixFactorizationTestFullBatch(WALSMatrixFactorizationTest):

  @property
  def batch_size(self):
    return 100


class WALSMatrixFactorizationUnsupportedTest(test.TestCase):

  def setUp(self):
    pass

  def testDistributedWALSUnsupported(self):
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
        },
        'task': {
            'type': run_config_lib.TaskType.WORKER,
            'index': 1
        }
    }
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config.RunConfig()
      self.assertEqual(config.num_worker_replicas, 2)
      with self.assertRaises(ValueError):
        self._model = wals_lib.WALSMatrixFactorization(1, 1, 1, config=config)


class SweepHookTest(test.TestCase):

  def setUp(self):
@@ -45,7 +340,7 @@ class SweepHookTest(test.TestCase):

   def run_hook_with_indices(self, sweep_hook, row_indices, col_indices):
     with self.test_session() as sess:
-      # Before run
+      # Before run.
       run_context = session_run_hook.SessionRunContext(
           original_args=None, session=sess)
       sess_run_args = sweep_hook.before_run(run_context)
@@ -53,11 +348,11 @@ class SweepHookTest(test.TestCase):
           self._input_row_indices_ph: row_indices,
           self._input_col_indices_ph: col_indices
       }
-      # Run
+      # Run.
       run_results = sess.run(sess_run_args.fetches, feed_dict=feed_dict)
       run_values = session_run_hook.SessionRunValues(
           results=run_results, options=None, run_metadata=None)
-      # After run
+      # After run.
       sweep_hook.after_run(run_context, run_values)

   def test_row_sweep(self):
@@ -74,9 +369,9 @@ class SweepHookTest(test.TestCase):
           self._col_prep_ops,
           self._init_ops)

-      # Initialize variables
+      # Initialize variables.
       sess.run([variables.global_variables_initializer()])
-      # Row sweep
+      # Row sweep.
       self.run_hook_with_indices(sweep_hook, [], [])
       self.assertTrue(sess.run(self._init_done),
                       msg='init ops not run by the sweep_hook')

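The test_eval comment above relies on projection being idempotent: projecting rows whose factors were just solved for reproduces those same factors, so the eval loss matches the model loss. A simplified dense numpy sketch of why (the `project_rows` helper is an illustrative assumption, not the TF code, which additionally handles sparsity and the unobserved-entry weighting):

```python
import numpy as np

def project_rows(a_rows, v, weights, reg):
    """Solve the same normal equations a row sweep solves (dense sketch)."""
    k = v.shape[1]
    u = np.zeros((a_rows.shape[0], k))
    for i in range(a_rows.shape[0]):
        w = np.diag(weights[i])
        u[i] = np.linalg.solve(v.T @ w @ v + reg * np.eye(k),
                               v.T @ w @ a_rows[i])
    return u

rng = np.random.default_rng(1)
a = rng.normal(size=(4, 6))
v = rng.normal(size=(6, 3))
w = np.full((4, 6), 0.5)
u = project_rows(a, v, w, reg=0.01)          # "training" row sweep
u_again = project_rows(a, v, w, reg=0.01)    # projecting the same rows
assert np.allclose(u, u_again)               # projection is idempotent
```

Since the projected factors equal the trained factors, a loss evaluated on the projections is the same quantity `calculate_loss` computes from the checkpointed factors, up to the test's 1e-3 tolerance.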
@@ -351,7 +351,7 @@ def _sampled_scattered_embedding_lookup(
   # No need to validate the indices since we have checked the params
   # dimensions and we know the largest id.
   result = embedding_ops.embedding_lookup(
-      params, ids, partition_strategy="div", validate_indices=False)
+      params, ids, partition_strategy="div")

   return array_ops.reshape(result,
                            array_ops.concat([values_shape, [dimension]], 0))
@@ -681,19 +681,17 @@ def embedding_lookup_sparse_with_distributed_aggregation(
   return embeddings


-def _do_gather(params, ids, validate_indices=True, name=None):
+def _do_gather(params, ids, name=None):
   """Deals with doing gather differently for resource variables."""
   if isinstance(params, resource_variable_ops.ResourceVariable):
     return params.sparse_read(ids, name=name)
-  return array_ops.gather(
-      params, ids, name=name, validate_indices=validate_indices)
+  return array_ops.gather(params, ids, name=name)


 def _embedding_lookup_with_distributed_aggregation(params,
                                                    ids,
                                                    partition_strategy="mod",
                                                    name=None,
-                                                   validate_indices=True,
                                                    max_norm=None,
                                                    weights=None,
                                                    idx=None,
@@ -724,8 +722,7 @@ def _embedding_lookup_with_distributed_aggregation(params,
   params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
   if np == 1:
     with ops.colocate_with(params[0]):
-      ret = maybe_normalize(
-          _do_gather(params[0], ids, validate_indices=validate_indices))
+      ret = maybe_normalize(_do_gather(params[0], ids))
       ignore_weights = weights is None
       if not ignore_weights:
         if weights.dtype != ret.dtype:
@@ -803,9 +800,7 @@ def _embedding_lookup_with_distributed_aggregation(params,
   partitioned_result = []
   for p in xrange(np):
     with ops.colocate_with(params[p]):
-      partitioned_result.append(
-          _do_gather(
-              params[p], gather_ids[p], validate_indices=validate_indices))
+      partitioned_result.append(_do_gather(params[p], gather_ids[p]))

   ignore_weights = weights is None
   if not ignore_weights:
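The partitioned lookup above routes each id to a shard according to `partition_strategy`. A pure-Python sketch of the two standard strategies, assuming the usual TensorFlow semantics ("mod" interleaves ids round-robin; "div" assigns contiguous ranges, with the first `N % num_parts` partitions holding one extra id); the helper names are illustrative:

```python
def div_partition(ids, num_total, num_parts):
    """Assign each id to a partition under the contiguous "div" strategy.

    The first (num_total % num_parts) partitions hold one extra id, matching
    the shard sizes embedding_lookup expects for partition_strategy="div".
    """
    base, extras = divmod(num_total, num_parts)
    assignments = []
    for i in ids:
        if i < extras * (base + 1):
            assignments.append(i // (base + 1))
        else:
            assignments.append(extras + (i - extras * (base + 1)) // base)
    return assignments

def mod_partition(ids, num_parts):
    # "mod" strategy: round-robin by id.
    return [i % num_parts for i in ids]

# 10 ids over 3 partitions: "div" shard sizes are (4, 3, 3); "mod" interleaves.
assert div_partition(range(10), 10, 3) == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
assert mod_partition(range(10), 3) == [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
```

This mapping is what produces the per-partition `gather_ids[p]` lists that `_do_gather` is called with in the loop above.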
@@ -78,8 +78,9 @@ class LinearOperatorApplyOnly(linalg.LinearOperator):
   def _shape_tensor(self):
     return array_ops.shape(self._matrix)

-  def _apply(self, x, adjoint=False):
-    return math_ops.matmul(self._matrix, x, adjoint_a=adjoint)
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
+    return math_ops.matmul(
+        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)


 class LinearOperatorTest(test.TestCase):
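The new `adjoint_arg` flag mirrors matmul's `adjoint_b`: `adjoint` hermitian-transposes the operator, `adjoint_arg` hermitian-transposes the argument. A numpy sketch of the semantics and the resulting shape compatibility (the `apply_op` helper is illustrative, not TF API):

```python
import numpy as np

def apply_op(matrix, x, adjoint=False, adjoint_arg=False):
    """Mirror of the _apply change above: matmul with optional hermitian
    transposes of the operator (adjoint) and of the argument (adjoint_arg)."""
    a = matrix.conj().T if adjoint else matrix
    b = x.conj().T if adjoint_arg else x
    return a @ b

m = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # a 3 x 2 operator

# adjoint=True contracts over the operator's -2 dimension (3), so a (3 x 4)
# argument is compatible and the result is (2 x 4): A^H x.
x = np.ones((3, 4))
assert apply_op(m, x, adjoint=True).shape == (2, 4)

# adjoint_arg=True transposes the argument first: with a (4 x 2) argument,
# x^H is (2 x 4) and A x^H is (3 x 4).
x2 = np.ones((4, 2))
assert apply_op(m, x2, adjoint_arg=True).shape == (3, 4)
```

This is exactly the shape check introduced later in `LinearOperator.apply`: the operator's contracted dimension is `-2` when `adjoint` is set (else `-1`), and the argument's is `-1` when `adjoint_arg` is set (else `-2`).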
@@ -30,7 +30,6 @@ __all__ = ["LinearOperator"]


 # TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
-# TODO(langmore) Add adjoint_x arg to apply, solve.
 class LinearOperator(object):
   """Base class defining a [batch of] linear operator[s].

@@ -490,16 +489,18 @@ class LinearOperator(object):
           "Expected argument to have dtype %s. Found: %s in tensor %s"
           % (self.dtype, arg.dtype, arg))

-  def _apply(self, x, adjoint=False):
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
     raise NotImplementedError("_apply is not implemented.")

-  def apply(self, x, adjoint=False, name="apply"):
+  def apply(self, x, adjoint=False, adjoint_arg=False, name="apply"):
     """Transform `x` with left multiplication: `x --> Ax`.

     Args:
       x: `Tensor` with compatible shape and same `dtype` as `self`.
         See class docstring for definition of compatibility.
-      adjoint: Python `bool`. If `True`, left multiply by the adjoint.
+      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
+      adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
+        the hermitian transpose (transposition and complex conjugation).
       name: A name for this `Op`.

     Returns:
@@ -508,11 +509,12 @@ class LinearOperator(object):
     with self._name_scope(name, values=[x]):
       x = ops.convert_to_tensor(x, name="x")
       self._check_input_dtype(x)
-      if adjoint:
-        self.shape[-2].assert_is_compatible_with(x.get_shape()[-2])
-      else:
-        self.shape[-1].assert_is_compatible_with(x.get_shape()[-2])
-      return self._apply(x, adjoint=adjoint)
+
+      self_dim = -2 if adjoint else -1
+      arg_dim = -1 if adjoint_arg else -2
+      self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
+
+      return self._apply(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

   def _determinant(self):
     raise NotImplementedError("_det is not implemented.")
@@ -558,13 +560,13 @@ class LinearOperator(object):
     with self._name_scope(name):
       return self._log_abs_determinant()

-  def _solve(self, rhs, adjoint=False):
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     # Since this is an exact solve method for all rhs, this will only be
     # available for non-singular (batch) operators, in particular the operator
     # must be square.
     raise NotImplementedError("_solve is not implemented.")

-  def solve(self, rhs, adjoint=False, name="solve"):
+  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
     """Solve `R` (batch) systems of equations exactly: `A X = rhs`.

     Examples:
@@ -588,7 +590,9 @@ class LinearOperator(object):
       rhs: `Tensor` with same `dtype` as this operator and compatible shape.
         See class docstring for definition of compatibility.
       adjoint: Python `bool`. If `True`, solve the system involving the adjoint
-        of this `LinearOperator`.
+        of this `LinearOperator`: `A^H X = rhs`.
+      adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
+        is the hermitian transpose (transposition and complex conjugation).
       name: A name scope to use for ops added by this method.

     Returns:
@@ -608,11 +612,12 @@ class LinearOperator(object):
     with self._name_scope(name, values=[rhs]):
       rhs = ops.convert_to_tensor(rhs, name="rhs")
       self._check_input_dtype(rhs)
-      if adjoint:
-        self.shape[-1].assert_is_compatible_with(rhs.get_shape()[-2])
-      else:
-        self.shape[-2].assert_is_compatible_with(rhs.get_shape()[-2])
-      return self._solve(rhs, adjoint=adjoint)
+
+      self_dim = -1 if adjoint else -2
+      arg_dim = -1 if adjoint_arg else -2
+      self.shape[self_dim].assert_is_compatible_with(rhs.get_shape()[arg_dim])
+
+      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

   def _to_dense(self):
     """Generic and often inefficient implementation. Override often."""
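The `self_dim`/`arg_dim` compatibility logic introduced above can be sketched densely in plain NumPy. `apply_dense` is a hypothetical stand-in written for illustration only, not TensorFlow API; it mirrors the contract that `adjoint` uses the operator's hermitian transpose and `adjoint_arg` uses the argument's.

```python
import numpy as np

def apply_dense(a, x, adjoint=False, adjoint_arg=False):
    """Dense sketch of the new LinearOperator.apply contract."""
    # Same compatibility check as the diff: the contracted dimension of the
    # (possibly adjointed) operator must match the row dimension of the
    # (possibly adjointed) argument.
    self_dim = -2 if adjoint else -1
    arg_dim = -1 if adjoint_arg else -2
    assert a.shape[self_dim] == x.shape[arg_dim], "incompatible shapes"
    a_op = np.conj(a).T if adjoint else a       # A or A^H
    x_arg = np.conj(x).T if adjoint_arg else x  # x or x^H
    return a_op @ x_arg

a = np.array([[1., 2.], [3., 4.], [5., 6.]])  # 3x2 operator
x = np.array([[1.], [1.]])                    # 2x1 argument

# For real x, A (x^T)^H equals A x, so both calls agree.
assert apply_dense(a, x).ravel().tolist() == [3., 7., 11.]
assert apply_dense(a, x.T, adjoint_arg=True).ravel().tolist() == [3., 7., 11.]
```

This is why the shape assertion picks `-1` for whichever side is adjointed: the hermitian transpose swaps the last two dimensions before the matmul.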
@@ -225,7 +225,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):

     return array_ops.concat((batch_shape, matrix_shape), 0)

-  def _apply(self, x, adjoint=False):
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
     # If self.operators = [A, B], and not adjoint, then
     # apply_order_list = [B, A].
     # As a result, we return A.apply(B.apply(x))
@@ -234,8 +234,9 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
     else:
       apply_order_list = list(reversed(self.operators))

-    result = x
-    for operator in apply_order_list:
+    result = apply_order_list[0].apply(
+        x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+    for operator in apply_order_list[1:]:
       result = operator.apply(result, adjoint=adjoint)
     return result

@@ -251,7 +252,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
       result += operator.log_abs_determinant()
     return result

-  def _solve(self, rhs, adjoint=False):
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     # TODO(langmore) Implement solve using solve_ls if some intermediate
     # operator maps to a high dimensional space.
     # In that case, an exact solve may still be possible.
@@ -264,8 +265,9 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
     else:
       solve_order_list = self.operators

-    solution = rhs
-    for operator in solve_order_list:
+    solution = solve_order_list[0].solve(
+        rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
+    for operator in solve_order_list[1:]:
       solution = operator.solve(solution, adjoint=adjoint)
     return solution

@@ -206,8 +206,9 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
             "This diagonal operator contained non-zero imaginary values. "
             " Thus it was not self-adjoint."))

-  def _apply(self, x, adjoint=False):
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
     diag_term = math_ops.conj(self._diag) if adjoint else self._diag
+    x = linear_operator_util.matrix_adjoint(x) if adjoint_arg else x
     diag_mat = array_ops.expand_dims(diag_term, -1)
     return diag_mat * x

@@ -218,8 +219,9 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
     return math_ops.reduce_sum(
         math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])

-  def _solve(self, rhs, adjoint=False):
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     diag_term = math_ops.conj(self._diag) if adjoint else self._diag
+    rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
     inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
     return rhs * inv_diag_mat

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.linalg.python.ops import linear_operator
+from tensorflow.contrib.linalg.python.ops import linear_operator_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -172,8 +173,9 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
   def _shape_tensor(self):
     return array_ops.shape(self._matrix)

-  def _apply(self, x, adjoint=False):
-    return math_ops.matmul(self._matrix, x, adjoint_a=adjoint)
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
+    return math_ops.matmul(
+        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

   def _determinant(self):
     if self._is_spd:
@@ -187,7 +189,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
     abs_det = math_ops.abs(self.determinant())
     return math_ops.log(abs_det)

-  def _solve(self, rhs, adjoint=False):
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+    rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
     if self._is_spd:
       return linalg_ops.cholesky_solve(self._chol, rhs)
     return linalg_ops.matrix_solve(self._matrix, rhs, adjoint=adjoint)
@@ -329,8 +329,9 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
       zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
       return x + zeros

-  def _apply(self, x, adjoint=False):
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
     # Note that adjoint has no effect since this matrix is self-adjoint.
+    x = linear_operator_util.matrix_adjoint(x) if adjoint_arg else x
     if self._assert_proper_shapes:
       aps = linear_operator_util.assert_compatible_matrix_dimensions(
           self, x)
@@ -343,8 +344,8 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
   def _log_abs_determinant(self):
     return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

-  def _solve(self, rhs, adjoint=False):
-    return self._apply(rhs)
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+    return self._apply(rhs, adjoint_arg=adjoint_arg)

   def _diag_part(self):
     return self._ones_diag()
@@ -616,7 +617,8 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
         imag_multiplier,
         message="LinearOperator was not self-adjoint")

-  def _apply(self, x, adjoint=False):
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
+    x = linear_operator_util.matrix_adjoint(x) if adjoint_arg else x
     if adjoint:
       matrix = self._multiplier_matrix_conj
     else:
@@ -634,7 +636,8 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
     return self._num_rows_cast_to_real_dtype * math_ops.log(
         self._abs_multiplier)

-  def _solve(self, rhs, adjoint=False):
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+    rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
     if adjoint:
       matrix = self._multiplier_matrix_conj
     else:
@@ -23,6 +23,7 @@ import numpy as np
 import six

 from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
+from tensorflow.contrib.linalg.python.ops import linear_operator_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -213,18 +214,26 @@ class LinearOperatorDerivedClassTest(test.TestCase):
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         for adjoint in False, True:
-          with self.test_session(graph=ops.Graph()) as sess:
-            sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
-            operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
-                shape, dtype, use_placeholder=use_placeholder)
-            x = self._make_x(operator, adjoint=adjoint)
-            op_apply = operator.apply(x, adjoint=adjoint)
-            mat_apply = math_ops.matmul(mat, x, adjoint_a=adjoint)
-            if not use_placeholder:
-              self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape())
-            op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply],
-                                               feed_dict=feed_dict)
-            self.assertAC(op_apply_v, mat_apply_v)
+          for adjoint_arg in False, True:
+            with self.test_session(graph=ops.Graph()) as sess:
+              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
+              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+                  shape, dtype, use_placeholder=use_placeholder)
+              x = self._make_x(operator, adjoint=adjoint)
+              # If adjoint_arg, compute A X^H^H = A X.
+              if adjoint_arg:
+                op_apply = operator.apply(
+                    linear_operator_util.matrix_adjoint(x),
+                    adjoint=adjoint, adjoint_arg=adjoint_arg)
+              else:
+                op_apply = operator.apply(x, adjoint=adjoint)
+              mat_apply = math_ops.matmul(mat, x, adjoint_a=adjoint)
+              if not use_placeholder:
+                self.assertAllEqual(
+                    op_apply.get_shape(), mat_apply.get_shape())
+              op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply],
+                                                 feed_dict=feed_dict)
+              self.assertAC(op_apply_v, mat_apply_v)

   def test_solve(self):
     self._skip_if_tests_to_skip_contains("solve")
@@ -232,18 +241,27 @@ class LinearOperatorDerivedClassTest(test.TestCase):
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         for adjoint in False, True:
-          with self.test_session(graph=ops.Graph()) as sess:
-            sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
-            operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
-                shape, dtype, use_placeholder=use_placeholder)
-            rhs = self._make_rhs(operator, adjoint=adjoint)
-            op_solve = operator.solve(rhs, adjoint=adjoint)
-            mat_solve = linalg_ops.matrix_solve(mat, rhs, adjoint=adjoint)
-            if not use_placeholder:
-              self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape())
-            op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve],
-                                               feed_dict=feed_dict)
-            self.assertAC(op_solve_v, mat_solve_v)
+          for adjoint_arg in False, True:
+            with self.test_session(graph=ops.Graph()) as sess:
+              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
+              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+                  shape, dtype, use_placeholder=use_placeholder)
+              rhs = self._make_rhs(operator, adjoint=adjoint)
+              # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
+              if adjoint_arg:
+                op_solve = operator.solve(
+                    linear_operator_util.matrix_adjoint(rhs),
+                    adjoint=adjoint, adjoint_arg=adjoint_arg)
+              else:
+                op_solve = operator.solve(
+                    rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
+              mat_solve = linalg_ops.matrix_solve(mat, rhs, adjoint=adjoint)
+              if not use_placeholder:
+                self.assertAllEqual(
+                    op_solve.get_shape(), mat_solve.get_shape())
+              op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve],
+                                                 feed_dict=feed_dict)
+              self.assertAC(op_solve_v, mat_solve_v)

   def test_add_to_tensor(self):
     self._skip_if_tests_to_skip_contains("add_to_tensor")
@@ -173,8 +173,9 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
         self._diag,
         message="Singular operator: Diagonal contained zero values.")

-  def _apply(self, x, adjoint=False):
-    return math_ops.matmul(self._tril, x, adjoint_a=adjoint)
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
+    return math_ops.matmul(
+        self._tril, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

   def _determinant(self):
     return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
@@ -183,7 +184,8 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
     return math_ops.reduce_sum(
         math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])

-  def _solve(self, rhs, adjoint=False):
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+    rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
     return linalg_ops.matrix_triangular_solve(
         self._tril, rhs, lower=True, adjoint=adjoint)

@@ -348,21 +348,21 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
     return array_ops.concat(
         [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)

-  def _apply(self, x, adjoint=False):
+  def _apply(self, x, adjoint=False, adjoint_arg=False):
     u = self.u
     v = self.v
     l = self.base_operator
     d = self.diag_operator

-    leading_term = l.apply(x, adjoint=adjoint)
+    leading_term = l.apply(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

     if adjoint:
-      uh_x = math_ops.matmul(u, x, adjoint_a=True)
+      uh_x = math_ops.matmul(u, x, adjoint_a=True, adjoint_b=adjoint_arg)
       d_uh_x = d.apply(uh_x, adjoint=adjoint)
       v_d_uh_x = math_ops.matmul(v, d_uh_x)
       return leading_term + v_d_uh_x
     else:
-      vh_x = math_ops.matmul(v, x, adjoint_a=True)
+      vh_x = math_ops.matmul(v, x, adjoint_a=True, adjoint_b=adjoint_arg)
       d_vh_x = d.apply(vh_x, adjoint=adjoint)
       u_d_vh_x = math_ops.matmul(u, d_vh_x)
       return leading_term + u_d_vh_x
@@ -398,7 +398,7 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):

     return log_abs_det_c + log_abs_det_d + log_abs_det_l

-  def _solve(self, rhs, adjoint=False):
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     if self.base_operator.is_non_singular is False:
       raise ValueError(
           "Solve not implemented unless this is a perturbation of a "
@@ -421,7 +421,7 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
     u = self.u

     # L^{-1} rhs
-    linv_rhs = l.solve(rhs, adjoint=adjoint)
+    linv_rhs = l.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
     # V^H L^{-1} rhs
     vh_linv_rhs = math_ops.matmul(v, linv_rhs, adjoint_a=True)
     # C^{-1} V^H L^{-1} rhs
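The operator being updated here is the low-rank perturbation `L + U D V^H`, and `_apply` computes `L x + U (D (V^H x))` without ever forming the dense sum. A NumPy sanity check of that factored structure (a sketch, not TF code):

```python
import numpy as np

rng = np.random.RandomState(0)
l = rng.randn(4, 4)        # base operator L
u = rng.randn(4, 2)        # update factor U
v = rng.randn(4, 2)        # update factor V
d = np.diag(rng.randn(2))  # diagonal D

x = rng.randn(4, 3)
# Dense reference: (L + U D V^H) x
dense = (l + u @ d @ v.conj().T) @ x
# Factored evaluation, as in _apply: leading_term + U (D (V^H x))
factored = l @ x + u @ (d @ (v.conj().T @ x))

assert np.allclose(dense, factored)
```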
@@ -24,6 +24,7 @@ py_library(
         "python/training/failure_tolerator.py",
         "python/training/feeder.py",
         "python/training/hparam.py",
+        "python/training/python_input.py",
         "python/training/resample.py",
         "python/training/sampling_ops.py",
         "python/training/sequence_queueing_state_saver.py",
@@ -46,8 +47,10 @@ py_library(
         "//tensorflow/python:logging_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:ops",
+        "//tensorflow/python:parsing_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:random_ops",
+        "//tensorflow/python:script_ops",
         "//tensorflow/python:state_ops",
         "//tensorflow/python:string_ops",
         "//tensorflow/python:summary",
@@ -243,6 +246,26 @@ py_test(
     ],
 )

+py_test(
+    name = "python_input_test",
+    size = "medium",
+    srcs = ["python/training/python_input_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["manual"],
+    deps = [
+        ":training_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:data_flow_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:parsing_ops",
+        "//tensorflow/python:training",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_test(
     name = "evaluation_test",
     size = "small",
@@ -35,6 +35,7 @@ See @{$python/contrib.training} guide.
 @@HParams
 @@HParamDef
 @@parse_values
+@@python_input
 """

 from __future__ import absolute_import
@@ -54,6 +55,7 @@ from tensorflow.contrib.training.python.training.evaluation import wait_for_new_
 from tensorflow.contrib.training.python.training.failure_tolerator import *
 from tensorflow.contrib.training.python.training.feeder import *
 from tensorflow.contrib.training.python.training.hparam import *
+from tensorflow.contrib.training.python.training.python_input import python_input
 from tensorflow.contrib.training.python.training.resample import *
 from tensorflow.contrib.training.python.training.sampling_ops import *
 from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
@@ -251,10 +251,16 @@ def bucket(tensors,
   else:
     which_dequeue = lambda q: q.dequeue_many

+  def make_list(t):
+    if isinstance(t, (list, tuple)):
+      return t
+    else:
+      return [t]
+
   enqueues_to_top = [
       top_queue.enqueue(
-          [constant_op.constant(i)] + which_dequeue(q)(
-              bs, name="read_bucket_%d" % i),
+          [constant_op.constant(i)] + make_list(which_dequeue(q)(
+              bs, name="read_bucket_%d" % i)),
           name="enqueue_from_bucket_%d" % i)
       for i, (q, bs) in enumerate(zip(bucket_queues, batch_size))
   ]
@@ -282,6 +288,8 @@ def bucket(tensors,
   dequeued = top_queue.dequeue(name="dequeue_top")
   which_bucket_dequeued = dequeued[0]
   dequeued = dequeued[1:]
+  if len(dequeued) == 1:
+    dequeued = dequeued[0]
   dequeued = _restore_sparse_tensors(dequeued, sparse_info)
   return (which_bucket_dequeued, _as_original_type(tensors, dequeued))

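The `make_list` helper exists because a queue dequeue returns a bare tensor for a single component but a list for several; the enqueue side must always concatenate lists, and the dequeue side unwraps a length-1 list back to the bare element. The shape of that fix in plain Python:

```python
def make_list(t):
    # Mirrors the helper added in the diff: normalize a single element
    # to a one-element list so `+` concatenation always works.
    return t if isinstance(t, (list, tuple)) else [t]

# Enqueue side: prepend the bucket index to the dequeued component(s).
assert [0] + make_list("a") == [0, "a"]            # single component
assert [0] + make_list(["a", "b"]) == [0, "a", "b"]  # multiple components

# Dequeue side: unwrap a length-1 list, matching the
# `if len(dequeued) == 1` branch in the diff.
dequeued = ["a"]
if len(dequeued) == 1:
    dequeued = dequeued[0]
assert dequeued == "a"
```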
new file (178 lines): tensorflow/contrib/training/python/training/python_input.py
@@ -0,0 +1,178 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for asynchronously reading data from python into queues."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops


def _process_yielded_dict(feature_values, keys, features, dtypes, shapes):
  """Read feature_values from the generator and emit a proper output dict."""
  if not isinstance(feature_values, dict):
    raise TypeError("generator must return dict, saw: %s" % feature_values)

  processed_values = {}
  for pk in keys:
    if feature_values.get(pk, None) is not None:
      processed_values[pk] = np.asarray(
          feature_values[pk], dtype=dtypes[pk].as_numpy_dtype)
      check_shape = tensor_shape.TensorShape(processed_values[pk].shape)
      if not shapes[pk].is_compatible_with(check_shape):
        raise ValueError(
            "Feature '%s' has shape %s that is incompatible with declared "
            "shape: %s" % (pk, check_shape, shapes[pk]))
      continue
    if isinstance(features[pk], parsing_ops.FixedLenFeature):
      if features[pk].default_value is not None:
        processed_values[pk] = np.asarray(
            features[pk].default_value, dtype=dtypes[pk].as_numpy_dtype)
    elif isinstance(features[pk], parsing_ops.FixedLenSequenceFeature):
      processed_values[pk] = np.empty(
          [0] + tensor_shape.TensorShape(features[pk].shape).as_list(),
          dtype=dtypes[pk].as_numpy_dtype)
    else:
      raise ValueError(
          "Expected generator to return key '%s' with non-empty value" % pk)

  return processed_values


def python_input(generator, features, name=None):
  """Easily feed data from a python generator into TensorFlow queues.

  Example usage:

  ```python
  def generator():
    for i in range(3):
      yield {"value": i}

  features = {
      "value": tf.FixedLenFeature(shape=[], dtype=dtypes.int32)
  }

  tensor_dict = tf.contrib.training.python_input(generator, features)
  batched_dict = tf.train.batch(
      tensor_dict, batch_size=2, allow_smaller_final_batch=True)

  s = tf.Session()
  tf.train.start_queue_runners()

  batch1 = s.run(batched_dict)  # returns {"value": np.array([0, 1])}
  batch2 = s.run(batched_dict)  # returns {"value": np.array([2])}
  s.run(batched_dict)  # error: Queue is closed (generator finished at i==3)
  ```

  Args:
    generator: A python generator that takes no arguments, and yields dicts
      containing a single minibatch entry one at a time.
    features: A python `dict` mapping keys expected from the generator to
      instances of `tf.FixedLenFeature`, or `tf.FixedLenSequenceFeature`.
    name: (Optional) A name for the operations.

  Returns:
    A dict mapping keys of the `features` dict to `Tensor` objects.
    These `Tensor` objects are outputs of a queue that is fed by `generator`.

  Raises:
    TypeError: If generator is not callable or features is not a dict.
    TypeError: If any of features' values are not a Feature object.
    NotImplementedError: If any of features' values are instances of
      `SparseFeature` or `VarLenFeature` (these are not currently supported).
    ValueError: If any FixedLenSequenceFeatures contain a default value
      (this field is not supported).
    ValueError: if any FixedLenSequenceFeatures have allow_missing=False
      (this field is not supported).
  """
  if not callable(generator):
    raise TypeError("generator must be callable, saw: %s" % generator)
  if not isinstance(features, dict):
    raise TypeError("features must be a dict, saw: %s"
                    % type(features).__name__)

  with ops.name_scope(name, "python_input"):
    shapes = {}
    dtypes = {}
    for k, v in features.items():
      if isinstance(v, parsing_ops.FixedLenFeature):
        if v.default_value is not None:
          value = ops.convert_to_tensor(v.default_value, dtype=v.dtype, name=k)
          shapes[k] = value.shape
          dtypes[k] = value.dtype
        else:
          tensor_shape.TensorShape(v.shape).assert_is_fully_defined()
          shapes[k] = tensor_shape.TensorShape(v.shape)
          dtypes[k] = v.dtype
      elif isinstance(v, parsing_ops.VarLenFeature):
        raise NotImplementedError("VarLenFeature not supported")
      elif isinstance(v, parsing_ops.SparseFeature):
        raise NotImplementedError("SparseFeature not supported")
      elif isinstance(v, parsing_ops.FixedLenSequenceFeature):
        if v.default_value is not None:
          raise ValueError("FixedLenSequenceFeature with default value not "
                           "supported")
        if not v.allow_missing:
          raise ValueError("FixedLenSequenceFeature with allow_missing=False "
                           "not supported")
        tensor_shape.TensorShape(v.shape).assert_is_fully_defined()
        shapes[k] = tensor_shape.TensorShape([None]).concatenate(v.shape)
        dtypes[k] = v.dtype
      else:
        raise TypeError(
            "Expected value for features key '%s' to be one of "
            "FixedLenFeature, VarLenFeature, SparseFeature, or "
            "FixedLenSequenceFeature. Got: %s" % (k, v))

    keys = list(shapes.keys())
    dtypes_list = [dtypes[pk] for pk in keys]

    counter = [0]
    lock = threading.Lock()
    iterator = iter(generator())

    def generator_iter():
      """Iterate through generator output and return np.arrays to py_func."""
      with lock:
        try:
          feature_values = next(iterator)
          counter[0] += 1
        except StopIteration as e:
          raise StopIteration("Iteration finished. Processed %d entries (%s)"
                              % (counter[0], e))

      processed_dict = _process_yielded_dict(
          feature_values, keys, features, dtypes, shapes)
      return [processed_dict[pk] for pk in keys]

    generator_pyfunc_values = script_ops.py_func(
        generator_iter, inp=[], Tout=dtypes_list, stateful=True)

    pyfunc_input = {k: v for (k, v) in zip(keys, generator_pyfunc_values)}
    for k, v in shapes.items():
      pyfunc_input[k].set_shape(v)

    return pyfunc_input


__all__ = ["python_input"]
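The core concurrency trick in `python_input` is the lock around the shared iterator: several queue-runner threads call the same `py_func`, so `next()` and the counter update must be serialized. That pattern can be isolated in pure Python (`make_locked_next` is a hypothetical helper written for illustration, not part of the API):

```python
import threading

def make_locked_next(generator):
    """Wrap a generator so many threads can safely pull from one iterator."""
    iterator = iter(generator())
    counter = [0]  # mutable cell, as in the diff, so the closure can update it
    lock = threading.Lock()

    def locked_next():
        with lock:
            value = next(iterator)  # raises StopIteration when exhausted
            counter[0] += 1
        return value

    return locked_next, counter

next_fn, counter = make_locked_next(lambda: iter([{"value": 0}, {"value": 1}]))
assert next_fn() == {"value": 0}
assert next_fn() == {"value": 1}
assert counter[0] == 2
```

In the real op the `StopIteration` escaping `locked_next` is what closes the queue, which is why the diff re-raises it with a "Iteration finished" message rather than swallowing it.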
new file (191 lines): tensorflow/contrib/training/python/training/python_input_test.py
@ -0,0 +1,191 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.training.python_input."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.contrib.training.python.training import bucket_ops
|
||||
from tensorflow.contrib.training.python.training import python_input
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.training import input as core_input
|
||||
from tensorflow.python.training import queue_runner_impl
|
||||
|
||||
|
||||
class PythonInputTest(test.TestCase):
|
||||
|
||||
def testGenerator(self):
|
||||
def simple_generator():
|
||||
for i in range(2):
|
||||
yield {"value": i, "ignored": 3}
|
||||
|
||||
simple_features = {
|
||||
"value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32)
|
||||
}
|
||||
tensors = python_input.python_input(simple_generator, simple_features)
|
||||
self.assertEqual(["value"], tensors.keys())
|
||||
self.assertEqual(dtypes.int32, tensors["value"].dtype)
|
||||
self.assertEqual((), tensors["value"].shape)
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual({"value": 0}, sess.run(tensors))
|
||||
self.assertEqual({"value": 1}, sess.run(tensors))
|
||||
with self.assertRaisesOpError("Iteration finished"):
|
||||
sess.run(tensors)
|
||||
|
||||
def testInvalidGenerator(self):
|
||||
generator1 = lambda: iter([{"value": "a"}])
|
||||
int_features = {
|
||||
"value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32)
|
||||
}
|
||||
tensors1 = python_input.python_input(generator1, int_features)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesOpError("invalid literal"):
|
||||
# Can't convert a string to an integer
|
||||
sess.run(tensors1)
|
||||
|
||||
generator2 = lambda: iter([None])
|
||||
tensors2 = python_input.python_input(generator2, int_features)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesOpError("generator must return dict"):
|
||||
sess.run(tensors2)
|
||||
|
||||
generator3 = lambda: iter([{"value": [1, 2]}])
|
||||
tensors3 = python_input.python_input(generator3, int_features)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesOpError("incompatible with declared shape"):
|
||||
sess.run(tensors3)
|
||||
|
||||
  def testGeneratorWorksWithBatching(self):
    def simple_generator():
      for i in range(5):
        yield {"value": i, "ignored": 3}

    simple_features = {
        "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32)
    }
    tensors = python_input.python_input(simple_generator, simple_features)

    # Request batches of size 4 at a time, the final batch may be smaller.
    batched_tensors = core_input.batch(tensors, batch_size=4,
                                       allow_smaller_final_batch=True)

    self.assertEqual(["value"], batched_tensors.keys())
    self.assertEqual(dtypes.int32, batched_tensors["value"].dtype)
    self.assertEqual([None], batched_tensors["value"].shape.as_list())

    with self.test_session() as sess:
      # The generator emits 5 items total. The first 4 are returned in
      # the first session run; the final one is returned in the
      # second. This works because allow_smaller_final_batch=True.
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
      r1 = sess.run(batched_tensors)
      r2 = sess.run(batched_tensors)
      self.assertAllEqual([0, 1, 2, 3], r1["value"])
      self.assertEqual([4], r2["value"])
      with self.assertRaisesOpError("Iteration finished"):
        sess.run(tensors)
      coord.request_stop()
      for thread in threads:
        thread.join()
  def testGeneratorWorksWithManyBatchingThreads(self):
    def simple_generator():
      for i in range(5000):
        yield {"value": i, "ignored": 3}

    simple_features = {
        "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32)
    }
    tensors = python_input.python_input(simple_generator, simple_features)

    # Request batches of size 20 at a time, the final batch may be smaller.
    _, batched_tensors = bucket_ops.bucket(
        tensors, which_bucket=tensors["value"] % 5,
        batch_size=20, num_buckets=5, num_threads=7, capacity=17,
        allow_smaller_final_batch=True)

    self.assertEqual(["value"], batched_tensors.keys())
    self.assertEqual(dtypes.int32, batched_tensors["value"].dtype)
    self.assertEqual([None], batched_tensors["value"].shape.as_list())

    with self.test_session() as sess:
      # The generator emits 5000 items total; keep draining batches until
      # the bucketing queues are exhausted.
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
      results = []
      while True:
        try:
          r = sess.run(batched_tensors)
          results.extend(r["value"].tolist())
        except errors.OutOfRangeError:
          break
      coord.request_stop()
      for thread in threads:
        thread.join()
      self.assertEqual(sorted(results),
                       list(range(5000)))
  def testVaryingFieldsInGenerator(self):
    def simple_generator():
      for i in range(2):
        yield {"value": i,
               "seqlen_value": np.ones((i, 1))}

    simple_features = {
        "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32),
        "seqlen_value": parsing_ops.FixedLenSequenceFeature(
            shape=[1], dtype=dtypes.float32, allow_missing=True),
        "empty_value": parsing_ops.FixedLenFeature(
            default_value=[-1, -2], dtype=dtypes.int32, shape=[2])
    }
    tensors = python_input.python_input(simple_generator, simple_features)
    self.assertEqual(
        set(["value", "seqlen_value", "empty_value"]), set(tensors.keys()))
    self.assertEqual(dtypes.int32, tensors["value"].dtype)
    self.assertEqual((), tensors["value"].shape)
    self.assertEqual(dtypes.float32, tensors["seqlen_value"].dtype)
    self.assertEqual([None, 1], tensors["seqlen_value"].shape.as_list())
    self.assertEqual(dtypes.int32, tensors["empty_value"].dtype)
    self.assertEqual([2], tensors["empty_value"].shape)

    with self.test_session() as sess:
      r1 = sess.run(tensors)
      self.assertAllEqual(0, r1["value"])
      self.assertAllEqual(np.ones((0, 1)), r1["seqlen_value"])
      self.assertAllEqual([-1, -2], r1["empty_value"])

      r2 = sess.run(tensors)
      self.assertAllEqual(1, r2["value"])
      self.assertAllEqual([[1]], r2["seqlen_value"])
      self.assertAllEqual([-1, -2], r2["empty_value"])

      with self.assertRaisesOpError("Iteration finished"):
        sess.run(tensors)


if __name__ == "__main__":
  test.main()
@@ -19,6 +19,7 @@ filegroup(
    srcs = [
        "devices.cc",
        "devices.h",
        "grappler_item.cc",
        "grappler_item.h",
        "utils.cc",
        "utils.h",

@@ -17,6 +17,7 @@ filegroup(
    srcs = glob(
        [
            "*_optimizer.*",
            "auto_parallel.*",
            "constant_folding.*",
            "model_pruner.*",
            "graph_rewriter.*",

@@ -210,6 +211,7 @@ cc_library(
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":auto_parallel",
        ":constant_folding",
        ":graph_optimizer",
        ":layout_optimizer",
@@ -25,7 +25,9 @@ namespace grappler {
// Automatically parallelize a graph by splitting in the batch dimension.
class AutoParallel : public GraphOptimizer {
 public:
  AutoParallel(int num_replicas) : num_replicas_(num_replicas) {}
  AutoParallel(int num_replicas) : num_replicas_(num_replicas) {
    CHECK(num_replicas_ >= 2);
  }
  ~AutoParallel() override {}

  string name() const override { return "autoparallel"; };
@@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"

@@ -41,6 +42,10 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
  if (optimizer == "memory") {
    graph_optimizer.reset(new MemoryOptimizer());
  }
  if (optimizer == "autoparallel") {
    graph_optimizer.reset(
        new AutoParallel(cfg_.auto_parallel().num_replicas()));
  }
  return graph_optimizer;
}

@@ -63,11 +68,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
      optimizers.push_back(
          std::unique_ptr<GraphOptimizer>(new MemoryOptimizer()));
    }
    if (cfg_.auto_parallel().enable()) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new AutoParallel(cfg_.auto_parallel().num_replicas())));
    }
  } else {
    std::set<string> avaliable_optimizers = {"pruning", "constfold", "layout",
                                             "memory"};
    std::set<string> available_optimizers = {"pruning", "constfold", "layout",
                                             "memory", "autoparallel"};
    for (const auto& optimizer : cfg_.optimizers()) {
      if (avaliable_optimizers.find(optimizer) != avaliable_optimizers.end()) {
      if (available_optimizers.find(optimizer) != available_optimizers.end()) {
        optimizers.push_back(NewOptimizer(optimizer));
      }
    }

@@ -102,7 +111,8 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
}

bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
  return cfg.optimize_tensor_layout();
  return cfg.optimize_tensor_layout() || cfg.constant_folding() ||
         cfg.auto_parallel().enable() || !cfg.optimizers().empty();
}

Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
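The name-based selection in `MetaOptimizer::Optimize` above can be restated as a small sketch: only names present in the available set are instantiated, and the order of `cfg.optimizers()` determines the order the passes run in. This is a pure-Python restatement of the loop shown in the diff, not TF code:

```python
# Names drawn from the diff's available_optimizers set.
AVAILABLE = {"pruning", "constfold", "layout", "memory", "autoparallel"}


def select_optimizers(requested):
    """Keep requested optimizer names that exist, preserving request order."""
    return [name for name in requested if name in AVAILABLE]


# Unknown names are silently skipped, mirroring the C++ loop.
assert select_optimizers(["constfold", "bogus", "autoparallel"]) == [
    "constfold", "autoparallel"]
```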
@@ -161,9 +161,11 @@ TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
#undef REGISTER_ADDN_CPU

#if GOOGLE_CUDA
REGISTER_ADDN(Eigen::half, GPU);
REGISTER_ADDN(float, GPU);
REGISTER_ADDN(double, GPU);
#define REGISTER_ADDN_GPU(type) REGISTER_ADDN(type, GPU)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ADDN_GPU);
TF_CALL_complex64(REGISTER_ADDN_GPU);
TF_CALL_complex128(REGISTER_ADDN_GPU);
#undef REGISTER_ADDN_GPU

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel

@@ -154,6 +154,8 @@ struct Add9Functor<GPUDevice, T> {
template struct functor::Add9Functor<GPUDevice, type>;

TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS);
TF_CALL_complex64(REGISTER_FUNCTORS);
TF_CALL_complex128(REGISTER_FUNCTORS);

#undef REGISTER_FUNCTORS
@@ -355,6 +355,24 @@ TEST_F(QuantizationUtilsTest, AvoidBias) {
    const int back_to_int = FloatToQuantized<quint8>(as_float, 0.0f, 2.0f);
    EXPECT_EQ(i, back_to_int);
  }

  // All perfectly representable floats should survive quantization, even
  // if we pick a range where min is not itself perfectly representable.
  const float min = -0.1375f;
  const float max = 1.1385f;
  const float step_size = (max - min) / 255.0f;
  const float tolerance = step_size / 1000.0f;
  // This is the smallest perfectly representable float in the range.
  float first_float = ceil(min / step_size) * step_size;
  // TODO(ahentz): The current version always incurs a small error, which we
  // need to account for. We should fix QuantizedToFloat<> to remove this bias.
  const float expected_error = first_float - min;
  ASSERT_GT(expected_error, tolerance);
  for (float f = first_float; f <= max; f += step_size) {
    const int as_int = FloatToQuantized<quint8>(f, min, max);
    const float back_to_float = QuantizedToFloat<quint8>(as_int, min, max);
    EXPECT_NEAR(f, back_to_float + expected_error, tolerance);
  }
}

TEST_F(QuantizationUtilsTest, RequantizeInNewRange) {
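The round-trip property the AvoidBias test checks can be sketched in pure Python. The helper names and the scale formula below are assumptions modeled on standard min/max affine quantization, not TF's exact `FloatToQuantized`/`QuantizedToFloat` implementation:

```python
def float_to_quantized(value, range_min, range_max, bits=8):
    """Map a float in [range_min, range_max] to an integer code."""
    levels = (1 << bits) - 1  # 255 steps for quint8
    scale = (range_max - range_min) / levels
    code = int(round((value - range_min) / scale))
    return max(0, min(levels, code))  # clamp out-of-range inputs


def quantized_to_float(code, range_min, range_max, bits=8):
    """Inverse mapping: integer code back to a float."""
    levels = (1 << bits) - 1
    scale = (range_max - range_min) / levels
    return range_min + code * scale


# The property exercised above: over [0.0, 2.0], every code survives a
# float round trip without bias pushing it onto a neighboring code.
assert all(
    float_to_quantized(quantized_to_float(i, 0.0, 2.0), 0.0, 2.0) == i
    for i in range(256))
```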
@@ -488,7 +488,7 @@ REGISTER_OP("SplitV")
      ShapeHandle output_shape;
      const Tensor* size_splits = c->input_tensor(1);
      if (rank == InferenceContext::kUnknownRank) {
        // If the rank of input tensor is unknown, then return unkown shapes.
        // If the rank of input tensor is unknown, then return unknown shapes.
        output_shape = c->UnknownShape();
        for (int i = 0; i < num_outputs; ++i) {
          c->set_output(i, output_shape);

@@ -497,7 +497,7 @@ REGISTER_OP("SplitV")
        // Throw error if input is a scalar.
        return errors::InvalidArgument("Can't split scalars");
      } else if (size_splits == nullptr || !c->ValueKnown(split_dimension)) {
        // If split dimension or tensor containing the split sizes is unkown,
        // If split dimension or tensor containing the split sizes is unknown,
        // then return unknown shapes of same rank as input.
        output_shape = c->UnknownShapeOfRank(rank);
        for (int i = 0; i < num_outputs; ++i) {

@@ -1328,8 +1328,8 @@ this operation will permute `params` accordingly.

`validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in
`indices` are always validated to be within range. If assigned to GPU,
out-of-bound indices result in unspecified behavior (currently the result is
`0`, but this may become an error in the future).
out-of-bound indices result in safe but unspecified behavior, which may include
raising an error.

<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/Gather.png" alt>
@@ -6,6 +6,11 @@ option java_outer_classname = "RewriterConfigProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

message AutoParallelOptions {
  bool enable = 1;
  int32 num_replicas = 2;
}

message RewriterConfig {
  bool optimize_tensor_layout = 1;
  bool disable_model_pruning = 2;

@@ -19,6 +24,8 @@ message RewriterConfig {
  }
  MemOptType memory_optimization = 4;

  AutoParallelOptions auto_parallel = 5;

  // If non-empty, will use this as an alternative way to specify a list of
  // optimizations to turn on and the order of the optimizations.
  repeated string optimizers = 100;
@@ -57,14 +57,17 @@ have the same device assignment.
with tf.device('/cpu:0'):
  a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
  b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
c = tf.matmul(a, b)
c = tf.matmul(a, b)
# Creates a session with log_device_placement set to True.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
# Runs the op.
print(sess.run(c))
```

You will see that now `a` and `b` are assigned to `cpu:0`.
You will see that now `a` and `b` are assigned to `cpu:0`. Since a device was
not explicitly specified for the `MatMul` operation, the TensorFlow runtime will
choose one based on the operation and available devices (`gpu:0` in this
example) and automatically copy tensors between devices if required.

```
Device mapping:
@@ -132,6 +132,7 @@ from tensorflow.python.ops import tensor_array_ops
# documentation, or remove.
_allowed_symbols = [
    'AttrValue',
    'AutoParallelOptions',
    'ConfigProto',
    'DeviceSpec',
    'Event',
@@ -124,7 +124,7 @@ class ScrollBar(object):
      raise ValueError("Insufficient height for ScrollBar (%d)" %
                       (self._max_y - self._min_y + 1))

  def _block_y(self):
  def _block_y(self, screen_coord_sys=False):
    """Get the 0-based y coordinate of the scroll block.

    This y coordinate takes into account the presence of the UP and DN buttons

@@ -132,9 +132,13 @@ class ScrollBar(object):
    location, the return value will be 1; at the bottom location, the return
    value will be self._scroll_bar_height - 2.

    Args:
      screen_coord_sys: (`bool`) whether the return value will be in the
        screen coordinate system.

    Returns:
      (int) 0-based y coordinate of the scroll block, in the ScrollBar
      coordinate system, i.e., not the screen coordinate system. For example,
      coordinate system by default. For example,
      when scroll position is at the top, this return value will be 1 (not 0,
      because of the presence of the UP button). When scroll position is at
      the bottom, this return value will be self._scroll_bar_height - 2

@@ -142,8 +146,10 @@ class ScrollBar(object):
      button).
    """

    return int(float(self._scroll_position) / (self._output_num_rows - 1) *
               (self._scroll_bar_height - 3)) + 1
    rel_block_y = int(
        float(self._scroll_position) / (self._output_num_rows - 1) *
        (self._scroll_bar_height - 3)) + 1
    return rel_block_y + self._min_y if screen_coord_sys else rel_block_y
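The `_block_y` arithmetic above can be exercised standalone. This is a sketch with the instance attributes turned into plain arguments (argument names are assumptions); the block travels between row 1 and `scroll_bar_height - 2`, leaving the first row for the UP button and the last for the DOWN button:

```python
def block_y(scroll_position, output_num_rows, scroll_bar_height,
            min_y=0, screen_coord_sys=False):
    # Linear interpolation of scroll position into the bar's interior rows.
    rel_block_y = int(
        float(scroll_position) / (output_num_rows - 1) *
        (scroll_bar_height - 3)) + 1
    # Optionally translate from ScrollBar-local rows to screen rows, as the
    # mouse handling in get_click_command needs.
    return rel_block_y + min_y if screen_coord_sys else rel_block_y


# Top of the output: block sits just below the UP button.
assert block_y(0, 10, 12) == 1
# Bottom of the output: block sits just above the DOWN button.
assert block_y(9, 10, 12) == 12 - 2
# Screen coordinates shift by the bar's top row.
assert block_y(0, 10, 12, min_y=5, screen_coord_sys=True) == 6
```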
  def layout(self):
    """Get the RichTextLines layout of the scroll bar.

@@ -192,9 +198,11 @@ class ScrollBar(object):
      return _SCROLL_UP_A_LINE
    elif mouse_y == self._max_y:
      return _SCROLL_DOWN_A_LINE
    elif mouse_y > self._block_y() and mouse_y < self._max_y:
    elif (mouse_y > self._block_y(screen_coord_sys=True) and
          mouse_y < self._max_y):
      return _SCROLL_DOWN
    elif mouse_y < self._block_y() and mouse_y > self._min_y:
    elif (mouse_y < self._block_y(screen_coord_sys=True) and
          mouse_y > self._min_y):
      return _SCROLL_UP
    else:
      return None

@@ -505,7 +513,7 @@ class CursesUI(base_ui.BaseUI):
  def get_help(self):
    return self._command_handler_registry.get_help()

  def _screen_create_command_textbox(self, existing_command):
  def _screen_create_command_textbox(self, existing_command=None):
    """Create command textbox on screen.

    Args:

@@ -839,6 +847,7 @@ class CursesUI(base_ui.BaseUI):
      else:
        command = self._fetch_hyperlink_command(mouse_x, mouse_y)
        if command:
          self._screen_create_command_textbox()
          exit_token = self._dispatch_command(command)
          if exit_token is not None:
            raise debugger_cli_common.CommandLineExit(exit_token=exit_token)

@@ -898,13 +907,14 @@ class CursesUI(base_ui.BaseUI):
    """Automatically key in a command to the command Textbox.

    Args:
      command: The command, as a string.
      command: The command, as a string or None.
      erase_existing: (bool) whether existing text (if any) is to be erased
        first.
    """
    if erase_existing:
      self._erase_existing_command()

    command = command or ""
    for c in command:
      self._command_textbox.do_command(ord(c))

@@ -1227,9 +1237,9 @@ class CursesUI(base_ui.BaseUI):

    self._scroll_bar = ScrollBar(
        self._max_x - 2,
        2,
        3,
        self._max_x - 1,
        self._output_num_rows,
        self._output_num_rows + 1,
        self._output_pad_row,
        self._output_pad_height - self._output_pad_screen_height)
@@ -113,7 +113,7 @@ class MockCursesUI(curses_ui.CursesUI):
  def _screen_create_command_window(self):
    pass

  def _screen_create_command_textbox(self, existing_command):
  def _screen_create_command_textbox(self, existing_command=None):
    """Override to insert observer of existing commands.

    Used in testing of history navigation and tab completion.

@@ -1646,6 +1646,25 @@ class ScrollBarTest(test_util.TensorFlowTestCase):
                     scroll_bar.get_click_command(7))
    self.assertIsNone(scroll_bar.get_click_command(8))

  def testClickCommandsAreCorrectForScrollBarNotAtZeroMinY(self):
    scroll_bar = curses_ui.ScrollBar(0, 5, 1, 12, 10, 20)
    self.assertIsNone(scroll_bar.get_click_command(0))
    self.assertIsNone(scroll_bar.get_click_command(4))
    self.assertEqual(curses_ui._SCROLL_UP_A_LINE,
                     scroll_bar.get_click_command(5))
    self.assertEqual(curses_ui._SCROLL_UP,
                     scroll_bar.get_click_command(6))
    self.assertEqual(curses_ui._SCROLL_UP,
                     scroll_bar.get_click_command(7))
    self.assertIsNone(scroll_bar.get_click_command(8))
    self.assertEqual(curses_ui._SCROLL_DOWN,
                     scroll_bar.get_click_command(10))
    self.assertEqual(curses_ui._SCROLL_DOWN,
                     scroll_bar.get_click_command(11))
    self.assertEqual(curses_ui._SCROLL_DOWN_A_LINE,
                     scroll_bar.get_click_command(12))
    self.assertIsNone(scroll_bar.get_click_command(13))


if __name__ == "__main__":
  googletest.main()
@@ -51,6 +51,7 @@ class ModeKeys(object):
class MetricKeys(object):
  """Metric key strings."""
  LOSS = 'loss'
  AVERAGE_LOSS = 'average_loss'


class EstimatorSpec(

@@ -941,6 +941,19 @@ tf_py_test(
    ],
)

cuda_py_test(
    name = "aggregate_ops_test",
    size = "small",
    srcs = ["aggregate_ops_test.py"],
    additional_deps = [
        "//third_party/py/numpy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:math_ops",
    ],
)

cuda_py_test(
    name = "argmax_op_test",
    size = "small",
tensorflow/python/kernel_tests/aggregate_ops_test.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for aggregate_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class AddNTest(test.TestCase):
  # AddN special-cases adding the first M inputs to make (N - M) divisible by 8,
  # after which it adds the remaining (N - M) tensors 8 at a time in a loop.
  # Test N in [1, 10] so we check each special-case from 1 to 9 and one
  # iteration of the loop.
  _MAX_N = 10

  def _supported_types(self):
    if test.is_gpu_available():
      return [dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
              dtypes.complex128]
    return [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
            dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
            dtypes.complex128]

  def _buildData(self, shape, dtype):
    data = np.random.randn(*shape).astype(dtype.as_numpy_dtype)
    # For complex types, add an index-dependent imaginary component so we can
    # tell we got the right value.
    if dtype.is_complex:
      return data + 10j * data
    return data

  def testAddN(self):
    np.random.seed(12345)
    with self.test_session(use_gpu=True) as sess:
      for dtype in self._supported_types():
        for count in range(1, self._MAX_N + 1):
          data = [self._buildData((2, 2), dtype) for _ in range(count)]
          actual = sess.run(math_ops.add_n(data))
          expected = np.sum(np.vstack(
              [np.expand_dims(d, 0) for d in data]), axis=0)
          tol = 5e-3 if dtype == dtypes.float16 else 5e-7
          self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  def testUnknownShapes(self):
    np.random.seed(12345)
    with self.test_session(use_gpu=True) as sess:
      for dtype in self._supported_types():
        data = self._buildData((2, 2), dtype)
        for count in range(1, self._MAX_N + 1):
          data_ph = array_ops.placeholder(dtype=dtype)
          actual = sess.run(math_ops.add_n([data_ph] * count), {data_ph: data})
          expected = np.sum(np.vstack([np.expand_dims(data, 0)] * count),
                            axis=0)
          tol = 5e-3 if dtype == dtypes.float16 else 5e-7
          self.assertAllClose(expected, actual, rtol=tol, atol=tol)


if __name__ == "__main__":
  test.main()
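The chunking invariant stated in the comment at the top of `AddNTest` (add a first group of M inputs so that the remaining N - M are a multiple of 8, then consume them 8 at a time) can be sketched as pure arithmetic. This is only an illustration of the stated invariant; the formula for M is an assumption, not the kernel's actual code:

```python
def chunk_sizes(n):
    """Split n inputs into a first chunk m plus groups of 8, per the comment."""
    # Choose m in [1, 8] so that (n - m) is divisible by 8.
    m = ((n - 1) % 8) + 1
    return [m] + [8] * ((n - m) // 8)


# n = 10 (the test's _MAX_N): one special-case add of 2, then one loop pass.
assert chunk_sizes(10) == [2, 8]
# Every schedule consumes exactly n inputs, with the tail in groups of 8.
assert all(sum(chunk_sizes(n)) == n for n in range(1, 100))
```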
@@ -50,8 +50,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import six
from six.moves import xrange  # pylint: disable=redefined-builtin

@@ -426,10 +424,11 @@ def merge(inputs, name=None):
  # pylint: enable=protected-access


def _convert_tensorarrays_to_flows(tensors_or_tensor_arrays):
  return [ta.flow if isinstance(ta, tensor_array_ops.TensorArray)
          else ta
          for ta in tensors_or_tensor_arrays]
def _convert_tensorarray_to_flow(tensor_or_tensor_array):
  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
    return tensor_or_tensor_array.flow
  else:
    return tensor_or_tensor_array
def _make_tensor_array(ta, t_or_flow):

@@ -1637,63 +1636,77 @@ class CondContext(ControlFlowContext):
        real_val = external_val
    return real_val

  def _BuildCondTensor(self, v):
    if isinstance(v, ops.Operation):
      # Use pivot as the proxy for this op.
      return with_dependencies([v], self._pivot)
    elif isinstance(v, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
      values = self._ProcessOutputTensor(v.values)
      indices = self._ProcessOutputTensor(v.indices)
      if isinstance(v, ops.IndexedSlices):
        dense_shape = v.dense_shape
        if dense_shape is not None:
          dense_shape = self._ProcessOutputTensor(dense_shape)
        return ops.IndexedSlices(values, indices, dense_shape)
      else:
        dense_shape = self._ProcessOutputTensor(v.dense_shape)
        return sparse_tensor.SparseTensor(indices, values, dense_shape)
    else:
      v = nest.map_structure(_convert_tensorarray_to_flow, v)
      return self._ProcessOutputTensor(ops.convert_to_tensor(v))

  def BuildCondBranch(self, fn):
    """Add the subgraph defined by fn() to the graph."""
    r = fn()
    original_r = r
    result = []
    if r is not None:
      if not isinstance(r, list) and not isinstance(r, _basetuple):
        r = [r]
        original_r = [original_r]
      r = _convert_tensorarrays_to_flows(r)
      for v in r:
        real_v = v
        if isinstance(v, ops.Operation):
          # Use pivot as the proxy for this op.
          real_v = with_dependencies([v], self._pivot)
        else:
          if isinstance(v, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
            values = self._ProcessOutputTensor(v.values)
            indices = self._ProcessOutputTensor(v.indices)
            if isinstance(v, ops.IndexedSlices):
              dense_shape = v.dense_shape
              if dense_shape is not None:
                dense_shape = self._ProcessOutputTensor(dense_shape)
              real_v = ops.IndexedSlices(values, indices, dense_shape)
            else:
              dense_shape = self._ProcessOutputTensor(v.dense_shape)
              real_v = sparse_tensor.SparseTensor(indices, values, dense_shape)
          else:
            real_v = self._ProcessOutputTensor(v)
        result.append(real_v)
    return original_r, result
    original_result = fn()
    if original_result is None:
      return None, None

    result = nest.map_structure(self._BuildCondTensor, original_result)
    if not isinstance(result, (list, _basetuple)):
      result = [result]
    return original_result, result
def cond(pred, fn1, fn2, name=None):
  """Return either fn1() or fn2() based on the boolean predicate `pred`.
def _UnpackIfSingleton(res):
  if isinstance(res, (list, _basetuple)) and len(res) == 1:
    return res[0]
  else:
    return res


def cond(pred, fn1, fn2, strict=False, name=None):
  """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`.

  `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
  the same non-zero number and type of outputs.

  Note that the conditional execution applies only to the operations defined in
  fn1 and fn2. Consider the following simple program:
  `fn1` and `fn2`. Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If x < y, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since z is needed for at least one
  branch of the cond, the `tf.multiply` operation is always executed, unconditionally.
  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.
  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has occasionally surprised some users who expected a lazier semantics.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
  (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `fn1` and/or `fn2`, they are implicitly unpacked to single values. This
  behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.

  Returns:
@@ -1738,23 +1751,43 @@ def cond(pred, fn1, fn2, name=None):
    # Build the graph for the true branch in a new context.
    context_t = CondContext(pred, pivot_1, branch=1)
    context_t.Enter()
    orig_res, res_t = context_t.BuildCondBranch(fn1)
    orig_res_t, res_t = context_t.BuildCondBranch(fn1)
    if orig_res_t is None:
      raise ValueError("fn1 must have a return value.")
    context_t.ExitResult(res_t)
    context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = CondContext(pred, pivot_2, branch=0)
    context_f.Enter()
    _, res_f = context_f.BuildCondBranch(fn2)
    orig_res_f, res_f = context_f.BuildCondBranch(fn2)
    if orig_res_f is None:
      raise ValueError("fn2 must have a return value.")
    context_f.ExitResult(res_f)
    context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f)
    except TypeError as e:
      raise TypeError(
          "Incompatible return types of fn1 and fn2: {}".format(e))
    except ValueError as e:
      raise ValueError(
          "Incompatible return values of fn1 and fn2: {}".format(e))

    # Add the final merge to the graph.
    if len(res_t) != len(res_f):
      raise ValueError("fn1 and fn2 must return the same number of results.")
    if not res_t:
      raise ValueError("fn1 and fn2 must return at least one result.")
    for x, y in zip(res_t, res_f):

    res_t_flat = nest.flatten(res_t)
    res_f_flat = nest.flatten(res_f)

    for x, y in zip(res_t_flat, res_f_flat):
      assert ((isinstance(x, ops.IndexedSlices) and
               isinstance(y, ops.IndexedSlices)) or
              (isinstance(x, sparse_tensor.SparseTensor) and

@@ -1765,14 +1798,20 @@ def cond(pred, fn1, fn2, name=None):
      if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
        raise ValueError("Outputs of fn1 and fn2 must have the same type: "
                         "%s, %s" % (val_x.dtype.name, val_y.dtype.name))
    merges = [merge([x[0], x[1]])[0] for x in zip(res_f, res_t)]
    merges = _convert_flows_to_tensorarrays(orig_res, merges)

    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)

    # Add to collections
    ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
    ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    return merges[0] if len(merges) == 1 else merges
    merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges
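The singleton-unpacking rule that non-strict `tf.cond` applies to branch results can be demonstrated standalone. This mirrors the `_UnpackIfSingleton` helper introduced in the diff (with plain `tuple` standing in for `_basetuple`):

```python
def unpack_if_singleton(res):
    """Return the sole element of a length-1 list/tuple, else res unchanged."""
    if isinstance(res, (list, tuple)) and len(res) == 1:
        return res[0]
    return res


assert unpack_if_singleton([42]) == 42        # singleton list -> bare value
assert unpack_if_singleton((1, 2)) == (1, 2)  # longer tuples pass through
assert unpack_if_singleton("x") == "x"        # non-sequences pass through
```

Passing `strict=True` to `cond` (or `case`) disables this unpacking, so a branch returning `[x]` yields a one-element list rather than `x`.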
def _resource_safe_shape(t):

@@ -2415,8 +2454,8 @@ class WhileContext(ControlFlowContext):
      # Store body_result to keep track of TensorArrays returned by body
      original_body_result = body_result
      # Convert TensorArrays returned by body into their flow variables
      flat_result = nest.flatten(body_result)
      result = _convert_tensorarrays_to_flows(flat_result)
      result = nest.map_structure(_convert_tensorarray_to_flow,
                                  nest.flatten(body_result))
      result = ops.convert_n_to_tensor_or_indexed_slices(result)

      # Add NextIteration and the back edges to complete the loop.

@@ -2446,9 +2485,9 @@ class WhileContext(ControlFlowContext):

    # Keep original_loop_vars to identify which are TensorArrays
    original_loop_vars = loop_vars
    flat_loop_vars = nest.flatten(loop_vars)
    # Convert TensorArrays to their flow variables
    loop_vars = _convert_tensorarrays_to_flows(flat_loop_vars)
    loop_vars = nest.map_structure(_convert_tensorarray_to_flow,
                                   nest.flatten(loop_vars))
    loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
    try:
      self.Enter()

@@ -2820,7 +2859,7 @@ def tuple(tensors, name=None, control_inputs=None):
  return tpl


def case(pred_fn_pairs, default, exclusive=False, name="case"):
def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
  """Create a case operation.

  The `pred_fn_pairs` parameter is a dict or list of pairs of size N.

@@ -2837,6 +2876,13 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
  are returned immediately. If none of the predicates evaluate to True, this
  operation returns the tensors generated by `default`.

  `tf.case` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
  (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `fn1` and/or `fn2`, they are implicitly unpacked to single values. This
  behavior is disabled by passing `strict=True`.

  Example 1:
  Pseudocode:
  ```

@@ -2877,6 +2923,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
      callable which returns a list of tensors.
    default: A callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).
|
||||
|
||||
Returns:
|
||||
@ -2941,20 +2988,31 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
|
||||
|
||||
# Create an empty tensor, or list, with the right type and shape
|
||||
with ops.name_scope("case_create_empty"):
|
||||
dummy_value = default()
|
||||
def _create_empty_constant(dtype, shape):
|
||||
value = ("" if dtype == dtypes.string else dtype.as_numpy_dtype())
|
||||
if shape.ndims is None:
|
||||
return array_ops.constant(value, dtype=dtype)
|
||||
else:
|
||||
temp_shape = [1 if x.value is None else x.value for x in shape]
|
||||
result = array_ops.constant(value, shape=temp_shape, dtype=dtype)
|
||||
result._shape = shape # pylint: disable=protected-access
|
||||
return result
|
||||
|
||||
def _correct_empty(v):
|
||||
if isinstance(v, ops.Operation):
|
||||
return no_op()
|
||||
elif v.dtype == dtypes.string:
|
||||
return array_ops.constant("")
|
||||
elif isinstance(v, tensor_array_ops.TensorArray):
|
||||
return v
|
||||
elif not hasattr(v, "dtype"):
|
||||
return ops.convert_to_tensor(v)
|
||||
elif isinstance(v, sparse_tensor.SparseTensor):
|
||||
return sparse_tensor.SparseTensor(indices=[[0] * len(v.get_shape())],
|
||||
values=[v.dtype.as_numpy_dtype()],
|
||||
dense_shape=v.get_shape())
|
||||
else:
|
||||
return array_ops.constant(v.dtype.as_numpy_dtype())
|
||||
return _create_empty_constant(v.dtype, v.get_shape())
|
||||
|
||||
if isinstance(dummy_value, collections.Sequence):
|
||||
dummy_type = type(dummy_value)
|
||||
empty = lambda: dummy_type(_correct_empty(v) for v in dummy_value)
|
||||
else:
|
||||
empty = lambda: _correct_empty(dummy_value)
|
||||
empty = lambda: nest.map_structure(_correct_empty, default())
|
||||
|
||||
# case_sequence = [
|
||||
# cond(~p3 & ~p2 & ~p1, default, empty),
|
||||
@ -2972,7 +3030,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
|
||||
prev_case = cond(
|
||||
cp, fn,
|
||||
empty if i == 0 else lambda: prev_case,
|
||||
name="If_%d" % i)
|
||||
strict=strict, name="If_%d" % i)
|
||||
return prev_case
|
||||
|
||||
if exclusive:
|
||||
@ -2994,6 +3052,8 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
|
||||
else:
|
||||
case_seq = _build_case()
|
||||
|
||||
if not strict:
|
||||
case_seq = _UnpackIfSingleton(case_seq)
|
||||
return case_seq
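The semantics that `_build_case` lowers into a chain of `cond` ops can be sketched eagerly in plain Python (an illustrative analog, assuming callables for both predicates and branch functions):

```python
def build_case(pred_fn_pairs, default):
    # Evaluate predicates in order; the first True predicate selects its fn,
    # otherwise the default branch runs -- mirroring the chained cond() ops
    # that tf.case constructs in the graph.
    for pred, fn in pred_fn_pairs:
        if pred():
            return fn()
    return default()
```

In the graph version, `exclusive=True` additionally inserts a runtime check that at most one predicate fired.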


@@ -18,11 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -37,9 +42,14 @@ from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.training import momentum
from tensorflow.python.util import nest
from tensorflow.python.util.protobuf import compare


TestTuple = collections.namedtuple("TestTuple", "a b")
SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a")


class GroupTestCase(TensorFlowTestCase):

  def _StripNode(self, nd):
@@ -334,5 +344,340 @@ class ContextTest(TensorFlowTestCase):
        control_flow_ops.WhileContext.from_proto(c.to_proto()).to_proto())


def _GetNestedShape(nested):
  def _GetShape(tensor):
    if isinstance(tensor, tensor_array_ops.TensorArray):
      return tensor_array_ops.TensorArray
    elif isinstance(tensor, ops.IndexedSlices):
      return tensor.dense_shape
    else:
      return tensor.get_shape()

  return nest.map_structure(_GetShape, nested)


def _CreateTensorArray(size, shape):
  ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size,
                                    clear_after_read=False)
  for i in range(size):
    ta = ta.write(i, array_ops.zeros(shape))
  return ta


def _RawNestedShape(nested_shape):
  def _RawShape(shape):
    if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None:
      return [x.value for x in shape]
    else:
      return None
  return nest.map_structure(_RawShape, nested_shape)
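Both helpers above lean on `nest.map_structure`; a minimal pure-Python analog (illustrative only, covering lists, tuples, and namedtuples) behaves like this:

```python
def map_structure(fn, structure):
    # Apply fn to every leaf of a nested structure, preserving the container
    # types, in the spirit of tensorflow.python.util.nest.map_structure.
    if isinstance(structure, tuple) and hasattr(structure, "_fields"):
        # Namedtuples are rebuilt positionally via their constructor.
        return type(structure)(*(map_structure(fn, v) for v in structure))
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)
```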


# TODO(yori): Add tests for indexed slices.
class DataTypesTest(TensorFlowTestCase):

  def assertAllEqualNested(self, a, b):
    if isinstance(a, (list, tuple)):
      for entry_a, entry_b in zip(a, b):
        self.assertAllEqualNested(entry_a, entry_b)
    else:
      self.assertAllEqual(a, b)

  def _testShape(self, fn_true, fn_false, expected_shape,
                 strict=False):
    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    self.assertEqual(_RawNestedShape(_GetNestedShape(output_cond)),
                     _RawNestedShape(expected_shape))

    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)
    self.assertEqual(_RawNestedShape(_GetNestedShape(output_case)),
                     _RawNestedShape(expected_shape))

  def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                        expected_value_false, strict=False,
                        check_cond=True):
    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)

    with self.test_session() as sess:
      variables.global_variables_initializer().run()
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict={condition: True})
      self.assertAllEqualNested(result_cond, expected_value_true)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_true)
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict={condition: False})
      self.assertAllEqualNested(result_cond, expected_value_false)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_false)

  def test_int(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1
    fn_false = lambda: 2
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 2)
    self._testShape(fn_true, fn_false, shape, strict=True)
    self._testReturnValues(fn_true, fn_false, 1, 2, strict=True)

  def test_float(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1.0
    fn_false = lambda: 2.0
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1.0, 2.0)

  def test_noop(self):
    shape = tensor_shape.TensorShape(None)
    self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
    self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
                           True, False, check_cond=False)

  def test_string(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: "abc"
    fn_false = lambda: "xyz"
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")

  def test_variable(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: variables.Variable(3.0)
    fn_false = lambda: variables.Variable(4.0)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 3.0, 4.0)

  def test_none(self):
    fn_none = lambda: None
    fn_tensor = lambda: constant_op.constant(1)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none)

  def test_tensors(self):
    def _BuildTrueBranch(dtype):
      def _Build():
        return (array_ops.zeros([2, 2], dtype=dtype),
                array_ops.ones([3, 3], dtype=dtype))
      return _Build

    def _BuildFalseBranch(dtype):
      def _Build():
        return (array_ops.ones([2, 2], dtype=dtype),
                array_ops.zeros([3, 3], dtype=dtype))
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([2, 2]),
               tensor_shape.TensorShape([3, 3]))
      fn_true = _BuildTrueBranch(dtype)
      fn_false = _BuildFalseBranch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.ones([3, 3])),
                             (np.ones([2, 2]), np.zeros([3, 3])))

  def test_tensors_unknown_shape(self):
    def _BuildTrueBranch(dtype):
      def _Build():
        tensor = array_ops.zeros([2, 2], dtype=dtype)
        tensor._shape = tensor_shape.TensorShape(None)
        return tensor
      return _Build

    def _BuildFalseBranch(dtype):
      def _Build():
        tensor = array_ops.ones([2, 2], dtype=dtype)
        tensor._shape = tensor_shape.TensorShape(None)
        return tensor
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = tensor_shape.TensorShape(None)
      fn_true = _BuildTrueBranch(dtype)
      fn_false = _BuildFalseBranch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             np.zeros([2, 2]), np.ones([2, 2]))

  def test_sparse_tensors(self):
    shape = tensor_shape.TensorShape([None, None])

    def FnTrue():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
                                         values=[1, 2], dense_shape=[3, 4])]

    def FnFalse():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
                                         values=[3, 4], dense_shape=[3, 4])]

    value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
                                             values=[1, 2], dense_shape=[3, 4])
    value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
                                             values=[3, 4], dense_shape=[3, 4])
    self._testShape(FnTrue, FnFalse, shape)
    self._testReturnValues(FnTrue, FnFalse, value1, value2)
    self._testShape(FnTrue, FnFalse, [shape], strict=True)
    self._testReturnValues(FnTrue, FnFalse, [value1], [value2], strict=True)

  def test_tensors_with_partially_specified_shapes(self):
    def _BuildBranch(dtype, shape):
      def _Build():
        a = array_ops.zeros([2, 2], dtype=dtype)
        b = array_ops.zeros([5], dtype=dtype)
        c = array_ops.ones([3, 3], dtype=dtype)
        a._shape = tensor_shape.TensorShape(shape[0])
        b._shape = tensor_shape.TensorShape(shape[1])
        c._shape = tensor_shape.TensorShape(shape[2])
        return a, b, c
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([None, 2]),
               tensor_shape.TensorShape([None]),
               tensor_shape.TensorShape([3, None]))
      fn_true = _BuildBranch(dtype, shape)
      fn_false = _BuildBranch(dtype, shape)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])))

  def test_tensor_arrays(self):
    element_shape = tensor_shape.TensorShape([2])
    ta1 = _CreateTensorArray(4, element_shape)
    ta2 = _CreateTensorArray(4, element_shape)
    shape = tensor_array_ops.TensorArray
    fn_true = lambda: ta1
    fn_false = lambda: ta2
    self._testShape(fn_true, fn_false, shape)

  def test_tensor_array_reads(self):
    shape = tensor_shape.TensorShape([2])
    ta = _CreateTensorArray(4, shape)
    fn_true = lambda: ta.read(0)
    fn_false = lambda: ta.read(1)
    self._testShape(fn_true, fn_false, shape)

  def test_list(self):
    shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
             tensor_shape.TensorShape([])]
    fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
    fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0])

  def test_non_strict(self):
    shape = tensor_shape.TensorShape([])
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)
    self._testShape(fn_tensor, fn_list, shape)
    self._testShape(fn_tensor, fn_tuple, shape)
    self._testShape(fn_list, fn_tuple, shape)
    self._testReturnValues(fn_tensor, fn_list, 1, 2)
    self._testReturnValues(fn_tensor, fn_tuple, 1, 3)
    self._testReturnValues(fn_list, fn_tuple, 2, 3)

  def test_singleton_strict(self):
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
                            strict=True)

    with self.assertRaises(ValueError):
      control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
                            strict=True)

  def test_singleton_list(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: [constant_op.constant(1)]
    fn_false = lambda: [constant_op.constant(3)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, [shape], strict=True)
    self._testReturnValues(fn_true, fn_false, [1], [3], strict=True)

  def test_singleton_tuple(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: (constant_op.constant(1),)
    fn_false = lambda: (constant_op.constant(3),)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, (shape,), strict=True)
    self._testReturnValues(fn_true, fn_false, (1,), (3,),
                           strict=True)

  def test_singleton_namedtuple(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: SingletonTestTuple(constant_op.constant(1))
    fn_false = lambda: SingletonTestTuple(constant_op.constant(3))
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
                    strict=True)
    self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
                           SingletonTestTuple(3), strict=True)

  def test_tuple(self):
    shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
    fn_true = lambda: (constant_op.constant(1), 2)
    fn_false = lambda: (constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4))

  def test_namedtuple(self):
    shape = TestTuple(tensor_shape.TensorShape([]),
                      tensor_shape.TensorShape([]))
    fn_true = lambda: TestTuple(constant_op.constant(1), 2)
    fn_false = lambda: TestTuple(constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4))

  def test_nested(self):
    shape = [tensor_shape.TensorShape([]),
             TestTuple(tensor_shape.TensorShape([]),
                       [tensor_shape.TensorShape([]),
                        tensor_shape.TensorShape([])]),
             tensor_shape.TensorShape([5, 5]),
             tensor_shape.TensorShape([])]

    def FnTrue():
      return [constant_op.constant(1),
              TestTuple(constant_op.constant(2), [3, 4]),
              array_ops.zeros([5, 5]), 6]

    def FnFalse():
      return [constant_op.constant(11),
              TestTuple(constant_op.constant(12), [13, 14]),
              array_ops.ones([5, 5]), 16]

    self._testShape(FnTrue, FnFalse, shape)
    self._testReturnValues(FnTrue, FnFalse,
                           [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
                           [11, TestTuple(12, [13, 14]), np.ones([5, 5]), 16])


if __name__ == "__main__":
  googletest.main()

@@ -33,16 +33,16 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging


def _do_gather(params, ids, validate_indices=True, name=None):
def _do_gather(params, ids, name=None):
  """Deals with doing gather differently for resource variables."""
  if isinstance(params, resource_variable_ops.ResourceVariable):
    return params.sparse_read(ids, name=name)
  return array_ops.gather(
      params, ids, name=name, validate_indices=validate_indices)
  return array_ops.gather(params, ids, name=name)
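The type-based dispatch in `_do_gather` can be illustrated with a small pure-Python sketch; `ResourceParams` is a hypothetical stand-in for `ResourceVariable` and its `sparse_read` method:

```python
class ResourceParams:
    # Hypothetical stand-in for a ResourceVariable: rows are read through
    # sparse_read rather than a plain gather.
    def __init__(self, rows):
        self._rows = rows

    def sparse_read(self, ids):
        return [self._rows[i] for i in ids]


def do_gather(params, ids):
    # Dispatch on type, as _do_gather does: resource variables get
    # sparse_read, everything else a plain gather (list indexing here).
    if isinstance(params, ResourceParams):
        return params.sparse_read(ids)
    return [params[i] for i in ids]
```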


def embedding_lookup(params, ids, partition_strategy="mod", name=None,
                     validate_indices=True, max_norm=None):
                     validate_indices=True,  # pylint: disable=unused-argument
                     max_norm=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
@@ -82,7 +82,10 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
    if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
    is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: Whether or not to validate gather indices.
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range. If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not None, embedding values are l2-normalized to the value of
      max_norm.

@@ -92,7 +95,7 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
  if params in (None, (), []):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
@@ -114,9 +117,7 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
    params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        return maybe_normalize(
            _do_gather(
                params[0], ids, validate_indices=validate_indices, name=name))
        return maybe_normalize(_do_gather(params[0], ids, name=name))
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
@@ -176,9 +177,7 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(
              _do_gather(params[p], gather_ids[p],
                         validate_indices=validate_indices))
          partitioned_result.append(_do_gather(params[p], gather_ids[p]))
      # Stitch these back together
      ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
                                         name=name)

@@ -95,7 +95,7 @@ def parse_numpy_printoption(kv_str):
        "Setting '%s' from the command line is not supported." % k)
  try:
    v = (v_type(v_str) if v_type is not bool
         else flags.BooleanParser().Parse(v_str))
         else flags.BooleanParser().parse(v_str))
  except ValueError as e:
    raise argparse.ArgumentTypeError(e.message)
  np.set_printoptions(**{k: v})
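Booleans are special-cased above because `bool("False")` is truthy in Python, so a plain type cast would mis-parse them. A minimal sketch of the parsing step (a hypothetical helper, not the TensorFlow flags API):

```python
def parse_kv_option(kv_str, option_types):
    # Split "key=value", look up the expected type, and parse booleans by
    # string content instead of truthiness.
    k, _, v_str = kv_str.partition("=")
    v_type = option_types[k]
    if v_type is bool:
        return k, v_str.strip().lower() in ("1", "true", "t", "yes")
    return k, v_type(v_str)
```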

@@ -106,7 +106,7 @@ class Coordinator(object):
  After a thread has called `coord.request_stop()` the other threads have a
  fixed time to stop, this is called the 'stop grace period' and defaults to 2
  minutes. If any of the threads is still alive after the grace period expires
  `coord.join()` raises a RuntimeException reporting the laggards.
  `coord.join()` raises a RuntimeError reporting the laggards.

  ```python
  try:
@@ -117,7 +117,7 @@ class Coordinator(object):
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate, give them 10s grace period
    coord.join(threads, stop_grace_period_secs=10)
  except RuntimeException:
  except RuntimeError:
    ...one of the threads took more than 10s to stop after request_stop()
    ...was called.
  except Exception:

@@ -68,6 +68,7 @@ See the @{$python/train} guide.
@@LoggingTensorHook
@@StopAtStepHook
@@CheckpointSaverHook
@@CheckpointSaverListener
@@NewCheckpointReader
@@StepCounterHook
@@NanLossDuringTrainingError
@@ -128,6 +129,7 @@ from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer
from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook
from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook
from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook
from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverListener
from tensorflow.python.training.basic_session_run_hooks import StepCounterHook
from tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError
from tensorflow.python.training.basic_session_run_hooks import NanTensorHook

@@ -39,6 +39,8 @@ py_binary(
        "//tensorflow/python:platform",
        "//tensorflow/tensorboard/backend:application",
        "//tensorflow/tensorboard/backend/event_processing:event_file_inspector",
        "//tensorflow/tensorboard/plugins/projector:projector_plugin",
        "//tensorflow/tensorboard/plugins/text:text_plugin",
        "@org_pocoo_werkzeug//:werkzeug",
    ],
)

@@ -65,9 +65,6 @@ py_library(
        "//tensorflow/python:platform",
        "//tensorflow/tensorboard/backend/event_processing:event_accumulator",
        "//tensorflow/tensorboard/backend/event_processing:event_multiplexer",
        "//tensorflow/tensorboard/plugins/debugger:debugger_plugin",
        "//tensorflow/tensorboard/plugins/projector:projector_plugin",
        "//tensorflow/tensorboard/plugins/text:text_plugin",
        "@org_pocoo_werkzeug//:werkzeug",
        "@six_archive//:six",
    ],

@@ -43,9 +43,6 @@ from tensorflow.tensorboard.backend import http_util
from tensorflow.tensorboard.backend import process_graph
from tensorflow.tensorboard.backend.event_processing import event_accumulator
from tensorflow.tensorboard.backend.event_processing import event_multiplexer
from tensorflow.tensorboard.plugins.debugger import debugger_plugin
from tensorflow.tensorboard.plugins.projector import projector_plugin
from tensorflow.tensorboard.plugins.text import text_plugin


DEFAULT_SIZE_GUIDANCE = {
@@ -97,18 +94,27 @@ class _OutputFormat(object):
  CSV = 'csv'


def standard_tensorboard_wsgi(logdir, purge_orphaned_data, reload_interval):
  """Construct a TensorBoardWSGIApp with standard plugins and multiplexer."""
def standard_tensorboard_wsgi(
    logdir,
    purge_orphaned_data,
    reload_interval,
    plugins):
  """Construct a TensorBoardWSGIApp with standard plugins and multiplexer.

  Args:
    logdir: The path to the directory containing events files.
    purge_orphaned_data: Whether to purge orphaned data.
    reload_interval: The interval at which the backend reloads more data in
      seconds.
    plugins: A list of plugins for TensorBoard to initialize.

  Returns:
    The new TensorBoard WSGI application.
  """
  multiplexer = event_multiplexer.EventMultiplexer(
      size_guidance=DEFAULT_SIZE_GUIDANCE,
      purge_orphaned_data=purge_orphaned_data)

  plugins = [
      debugger_plugin.DebuggerPlugin(),
      projector_plugin.ProjectorPlugin(),
      text_plugin.TextPlugin(),
  ]

  return TensorBoardWSGIApp(logdir, plugins, multiplexer, reload_interval)


@@ -54,15 +54,18 @@ from tensorflow.tensorboard.plugins import base_plugin
class FakePlugin(base_plugin.TBPlugin):
  """A plugin with no functionality."""

  def __init__(self, plugin_name, is_active_value):
  def __init__(self, plugin_name, is_active_value, routes_mapping):
    """Constructs a fake plugin.

    Args:
      plugin_name: The name of this plugin.
      is_active_value: Whether the plugin is active.
      routes_mapping: A dictionary mapping from route (string URL path) to the
        method called when a user issues a request to that route.
    """
    self.plugin_name = plugin_name
    self._is_active_value = is_active_value
    self._routes_mapping = routes_mapping

  def get_plugin_apps(self, multiplexer, logdir):
    """Returns a mapping from routes to handlers offered by this plugin.
@@ -72,9 +75,9 @@ class FakePlugin(base_plugin.TBPlugin):
      logdir: The path to the directory containing logs.

    Returns:
      An empty dict. This plugin offers no routes.
      A dictionary mapping from routes to handlers offered by this plugin.
    """
    return {}
    return self._routes_mapping

  def is_active(self):
    """Returns whether this plugin is active.
@@ -97,8 +100,8 @@ class TensorboardServerTest(test.TestCase):
        size_guidance=application.DEFAULT_SIZE_GUIDANCE,
        purge_orphaned_data=True)
    plugins = [
        FakePlugin(plugin_name='foo', is_active_value=True),
        FakePlugin(plugin_name='bar', is_active_value=False)
        FakePlugin(plugin_name='foo', is_active_value=True, routes_mapping={}),
        FakePlugin(plugin_name='bar', is_active_value=False, routes_mapping={})
    ]
    app = application.TensorBoardWSGIApp(
        self.temp_dir, plugins, multiplexer, reload_interval=0)
@@ -476,10 +479,41 @@ class TensorBoardAssetsTest(test.TestCase):
  def testTagFound(self):
    tag = application.get_tensorboard_tag()
    self.assertTrue(tag)
    app = application.standard_tensorboard_wsgi('', True, 60)
    app = application.standard_tensorboard_wsgi('', True, 60, [])
    self.assertEqual(app.tag, tag)


class TensorBoardPluginsTest(test.TestCase):

  def testPluginsAdded(self):

    def foo_handler():
      pass

    def bar_handler():
      pass

    plugins = [
        FakePlugin(
            plugin_name='foo',
            is_active_value=True,
            routes_mapping={'/foo_route': foo_handler}),
        FakePlugin(
            plugin_name='bar',
            is_active_value=True,
            routes_mapping={'/bar_route': bar_handler}),
    ]

    # The application should have added routes for both plugins.
    app = application.standard_tensorboard_wsgi('', True, 60, plugins)

    # The routes are prefixed with /data/plugin/[plugin name].
    self.assertDictContainsSubset({
        '/data/plugin/foo/foo_route': foo_handler,
        '/data/plugin/bar/bar_route': bar_handler,
    }, app.data_applications)
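The `/data/plugin/[plugin name]` prefixing this test checks can be sketched in plain Python (plugins modeled as dicts with `name` and `routes` keys, mirroring the test's `FakePlugin`; illustrative only):

```python
def build_plugin_routes(plugins):
    # Mount every route a plugin exposes under /data/plugin/<plugin name>,
    # so plugins cannot collide with each other or with core routes.
    data_applications = {}
    for plugin in plugins:
        for route, handler in plugin["routes"].items():
            key = "/data/plugin/%s%s" % (plugin["name"], route)
            data_applications[key] = handler
    return data_applications
```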
|
||||
|
||||
|
||||
class TensorboardSimpleServerConstructionTest(test.TestCase):
|
||||
"""Tests that the default HTTP server is constructed without error.
|
||||
|
||||
@ -533,14 +567,18 @@ class TensorBoardApplcationConstructionTest(test.TestCase):
|
||||
# Fails if there is an unnamed plugin
|
||||
with self.assertRaises(ValueError):
|
||||
# This plugin lacks a name.
|
||||
plugins = [FakePlugin(plugin_name=None, is_active_value=True)]
|
||||
plugins = [
|
||||
FakePlugin(plugin_name=None, is_active_value=True, routes_mapping={})
|
||||
]
|
||||
application.TensorBoardWSGIApp(logdir, plugins, multiplexer, 0)

    # Fails if there are two plugins with same name
    with self.assertRaises(ValueError):
      plugins = [
          FakePlugin(plugin_name='foo', is_active_value=True),
          FakePlugin(plugin_name='foo', is_active_value=True),
          FakePlugin(
              plugin_name='foo', is_active_value=True, routes_mapping={}),
          FakePlugin(
              plugin_name='foo', is_active_value=True, routes_mapping={}),
      ]
      application.TensorBoardWSGIApp(logdir, plugins, multiplexer, 0)

@@ -59,12 +59,16 @@ tf-audio-dashboard displays a dashboard that loads audio from a TensorFlow run.
  </style>
</template>
<script>
  Polymer({
  TF.Dashboard.TfAudioDashboard = Polymer({
    is: "tf-audio-dashboard",
    factoryImpl: function(backend) {
      this.backend = backend;
    },
    properties: {
      dataType: {value: "audio"},
    },
    behaviors: [
      TF.Dashboard.DashboardBehavior("audio"),
      TF.Dashboard.ReloadBehavior("tf-audio-loader"),
      TF.Backend.Behavior
    ],

@@ -41,6 +41,7 @@ tensorboard_typescript_genrule(
    name = "ts",
    srcs = [
        "categorizer.ts",
        "dashboard-behavior.ts",
        "reload-behavior.ts",
    ],
    typings = [
@@ -89,6 +90,7 @@ tensorboard_ts_library(
    name = "legacy_ts",
    srcs = [
        "categorizer.ts",
        "dashboard-behavior.ts",
        "reload-behavior.ts",
    ],
    deps = [

@@ -0,0 +1,40 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

module TF.Dashboard {
  /**
   * A behavior that TensorBoard dashboards must implement. This behavior serves
   * the purpose of an interface.
   */
  export function DashboardBehavior(dashboardName) {
    return {
      properties: {
        name: {
          type: String,
          value: dashboardName,
          readOnly: true,
        },
      },
      // This method is called when the dashboard reloads, either when the
      // dashboard is first visited, periodically reloaded, or manually reloaded
      // via the user clicking the button. Note that dashboard custom elements
      // that use TF.Dashboard.ReloadBehavior already implement a reload method.
      reload() {
        throw Error(
            'The ' + dashboardName + ' dashboard does not implement reload.');
      },
    };
  }
}
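The interface contract above (a default `reload` that throws until a dashboard supplies its own) can be sketched outside Polymer. The Python below is illustrative only; it mirrors the behavior-mixing idea with plain dicts:

```python
def dashboard_behavior(dashboard_name):
    """Returns default members every dashboard must eventually override."""
    def reload():
        # Same contract as the TypeScript behavior: fail loudly until the
        # dashboard provides its own reload.
        raise NotImplementedError(
            'The %s dashboard does not implement reload.' % dashboard_name)
    return {'name': dashboard_name, 'reload': reload}

# A dashboard "mixes in" the behavior and overrides reload.
scalar_dashboard = dict(dashboard_behavior('scalars'))
scalar_dashboard['reload'] = lambda: 'reloaded scalars'
```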

@@ -22,4 +22,5 @@ limitations under the License.
<link rel="import" href="tf-downloader.html">
<link rel="import" href="tf-no-data-warning.html">

<script src="dashboard-behavior.js"></script>
<script src="reload-behavior.js"></script>

@@ -139,17 +139,9 @@ Properties out:
      },
    },
    observers: [
      "_onBackendUpdate(backend)",
      "_logdirSet(logdir)",
    ],
    ready: function() {
      // Populate the logdir.
      this.backend.logdir().then(logdirObject => {
        this.set('logdir', logdirObject.logdir);
      }).catch(e => {
        // Fetching the logdir failed. Prevent the exception from logging to
        // console. The console already logs a 404 network event.
      });
    },
    _toggleAll: function() {
      this.$.multiCheckbox.toggleAll();
    },
@@ -157,8 +149,21 @@ Properties out:
    _breakString: function(originalString) {
      return originalString.replace(/([\/=\-_,])/g, "$1<wbr>");
    },
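The regex in `_breakString` can be checked in isolation: it inserts a `<wbr>` tag after every `/`, `=`, `-`, `_`, or `,` so long run names can soft-wrap in the selector. The same substitution expressed with Python's `re` (equivalent to the JS regex above):

```python
import re

def break_string(original_string):
    # Same substitution as _breakString: append <wbr> after each separator.
    return re.sub(r'([/=\-_,])', r'\1<wbr>', original_string)

broken = break_string('train/run_1')
```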
    _onBackendUpdate: function(backend) {
      if (backend === undefined) {
        return;
      }

      // When the backend is set, the selector can request the logdir.
      backend.logdir().then(logdirObject => {
        this.set('logdir', logdirObject.logdir);
      }).catch(e => {
        // Fetching the logdir failed. Prevent the exception from logging to
        // console. The console already logs a 404 network event.
      });
    },
    _logdirSet: function(logdir) {
      if (!logdir) {
      if (logdir === undefined) {
        // The logdir has not been set yet.
        return;
      }

@@ -101,9 +101,13 @@ contains vz-distribution-charts embedded inside tf-panes-helper's.
</template>

<script>
  Polymer({
  TF.Dashboard.TfDistributionDashboard = Polymer({
    is: "tf-distribution-dashboard",
    factoryImpl: function(backend) {
      this.backend = backend;
    },
    behaviors: [
      TF.Dashboard.DashboardBehavior("distributions"),
      TF.Dashboard.ReloadBehavior("tf-chart-scaffold"),
      TF.Backend.Behavior,
    ],

@@ -619,6 +619,18 @@ Polymer({
    renderHierarchy: Object,
    name: String,
    colorBy: String,

    // For each render hierarchy, we only fit it to the viewport once (when the scene is attached to
    // the DOM). We do not fit the hierarchy again (unless the user clicks the reset button). For
    // instance, if the user enters a certain view in the graph, switches to another dashboard, and
    // returns to the graph dashboard, the user expects the previous view. These properties enable
    // that behavior.

    /** Whether the scene has fit the current render hierarchy (to the viewport) at least once. */
    _hasRenderHierarchyBeenFitOnce: Boolean,
    /** Whether this scene element is currently attached to a parent element. */
    _isAttached: Boolean,

    /** @type {d3_zoom} d3 zoom object */
    _zoom: Object,
    highlightedNode: {
@@ -723,7 +735,10 @@ Polymer({
  },
  observers: [
    '_colorByChanged(colorBy)',
    '_buildAndFit(renderHierarchy)',
    '_renderHierarchyChanged(renderHierarchy)',
    // Animation and fitting must come after the observer for the hierarchy changing because we must
    // first build the render hierarchy.
    '_animateAndFit(_isAttached, renderHierarchy)',
    '_updateHealthPills(nodeNamesToHealthPills, healthPillStepIndex)',
  ],
  getNode: function(nodeName) {
@@ -826,9 +841,24 @@ Polymer({
        tf.graph.layout.PARAMS.minimap.size,
        tf.graph.layout.PARAMS.subscene.meta.labelHeight);
  },
  _buildAndFit: function(renderHierarchy) {
  attached: function() {
    this.set('_isAttached', true);
  },
  detached: function() {
    this.set('_isAttached', false);
  },
  _renderHierarchyChanged: function(renderHierarchy) {
    this._hasRenderHierarchyBeenFitOnce = false;
    this._resetState();
    this._build(renderHierarchy);
  },
  _animateAndFit: function(isAttached, renderHierarchy) {
    if (this._hasRenderHierarchyBeenFitOnce || !isAttached) {
      // Do not animate and fit if the scene has already fitted this render hierarchy once. Or if
      // the graph dashboard is not attached (in which case the scene lacks DOM info for fitting).
      return;
    }

    // Fit to screen after the graph is done animating.
    setTimeout(this.fit.bind(this), tf.graph.layout.PARAMS.animation.duration);
  },
@@ -881,6 +911,7 @@ Polymer({
    }
  },
  fit: function() {
    this._hasRenderHierarchyBeenFitOnce = true;
    tf.graph.scene.fit(this.$.svg, this.$.root, this._zoom, function() {
      this._zoomed = false;
    }.bind(this));

@@ -104,9 +104,14 @@ out-hierarchy-params="{{_hierarchyParams}}"

<script>
(function() {
  Polymer({
  TF.Dashboard.TfGraphDashboard = Polymer({
    is: 'tf-graph-dashboard',
    factoryImpl: function(backend, debuggerDataEnabled) {
      this.backend = backend;
      this.debuggerDataEnabled = debuggerDataEnabled;
    },
    behaviors: [
      TF.Dashboard.DashboardBehavior("graphs"),
      TF.Dashboard.ReloadBehavior("tf-graph-dashboard"),
      TF.Backend.Behavior,
    ],

@@ -121,9 +121,13 @@ contains vz-histogram-timeseries embedded inside tf-panes-helper's.
</template>

<script>
  Polymer({
  TF.Dashboard.TfHistogramDashboard = Polymer({
    is: "tf-histogram-dashboard",
    factoryImpl: function(backend) {
      this.backend = backend;
    },
    behaviors: [
      TF.Dashboard.DashboardBehavior("histograms"),
      TF.Dashboard.ReloadBehavior("tf-chart-scaffold"),
      TF.Backend.Behavior,
    ],

@@ -106,8 +106,11 @@ tf-image-dashboard displays a dashboard that loads images from a TensorFlow run.
  </style>
</template>
<script>
  Polymer({
  TF.Dashboard.TfImageDashboard = Polymer({
    is: "tf-image-dashboard",
    factoryImpl: function(backend) {
      this.backend = backend;
    },
    properties: {
      backend: Object,
      dataType: {
@@ -116,8 +119,9 @@ tf-image-dashboard displays a dashboard that loads images from a TensorFlow run.
      },
    },
    behaviors: [
      TF.Dashboard.ReloadBehavior("tf-chart-scaffold"),
      TF.Backend.Behavior,
      TF.Dashboard.DashboardBehavior("images"),
      TF.Dashboard.ReloadBehavior("tf-chart-scaffold"),
      TF.Backend.Behavior,
    ],
    attached: function() {
      this.async(function() {

@@ -190,9 +190,14 @@ contains vz-line-charts embedded inside tf-panes-helper's.
</template>

<script>
  Polymer({
  TF.Dashboard.TfScalarDashboard = Polymer({
    is: "tf-scalar-dashboard",
    factoryImpl: function(backend, router) {
      this.backend = backend;
      this.router = router;
    },
    behaviors: [
      TF.Dashboard.DashboardBehavior("scalars"),
      TF.Dashboard.ReloadBehavior("tf-chart-scaffold"),
      TF.Backend.Behavior,
    ],

@@ -57,9 +57,9 @@ allows the user to toggle between various dashboards.
      <div id="toolbar-content">
        <div class="toolbar-title">TensorBoard</div>
        <paper-tabs selected="{{modeIndex}}" noink class="tabs" id="tabs">
          <template is="dom-repeat" items="[[tabs]]">
            <template is="dom-if" if="[[_isTabEnabled(item)]]">
              <paper-tab data-mode="[[item]]">[[item]]</paper-tab>
          <template is="dom-repeat" items="[[_dashboards]]">
            <template is="dom-if" if="[[_isTabEnabled(item.name)]]">
              <paper-tab data-mode="[[item.name]]">[[item.name]]</paper-tab>
            </template>
          </template>
        </paper-tabs>
@@ -82,67 +82,7 @@ allows the user to toggle between various dashboards.
      </div>
    </paper-toolbar>

    <div id="content" class="fit">
      <content id="injected-overview"></content>

      <template is="dom-if" if="[[_modeIsScalars(mode)]]">
        <tf-scalar-dashboard
          id="scalars"
          backend="[[_backend]]"
          router="[[router]]"
        ></tf-scalar-dashboard>
      </template>

      <template is="dom-if" if="[[_modeIsImages(mode)]]">
        <tf-image-dashboard
          id="images"
          backend="[[_backend]]"
        ></tf-image-dashboard>
      </template>

      <template is="dom-if" if="[[_modeIsAudio(mode)]]">
        <tf-audio-dashboard
          id="audio"
          backend="[[_backend]]"
        ></tf-audio-dashboard>
      </template>

      <template is="dom-if" if="[[_modeIsGraphs(mode)]]">
        <tf-graph-dashboard
          id="graphs"
          backend="[[_backend]]"
          debugger-data-enabled="[[_debuggerDataEnabled]]"
        ></tf-graph-dashboard>
      </template>

      <template is="dom-if" if="[[_modeIsDistributions(mode)]]">
        <tf-distribution-dashboard
          id="distributions"
          backend="[[_backend]]"
        ></tf-distribution-dashboard>
      </template>

      <template is="dom-if" if="[[_modeIsHistograms(mode)]]">
        <tf-histogram-dashboard
          id="histograms"
          backend="[[_backend]]"
        ></tf-histogram-dashboard>
      </template>

      <template is="dom-if" if="[[_modeIsEmbeddings(mode)]]">
        <vz-projector-dashboard
          id="projector"
          route-prefix="/data/plugin/projector">
        </vz-projector-dashboard>
      </template>

      <template is="dom-if" if="[[_modeIsText(mode)]]">
        <tf-text-dashboard
          id="text"
          backend="[[_backend]]">
        </tf-text-dashboard>
      </template>
    </div>
    <div id="content" class="fit"></div>
  </paper-header-panel>

  <style>

@@ -233,16 +173,21 @@ allows the user to toggle between various dashboards.
          return match && match.length == 1;
        },
      },
      _dashboards: {
        type: Array,
        computed: "_makeDashboardList(_backend, router, _debuggerDataEnabled)",
      },
      // Maps dashboard name to dashboard object.
      _dashboardMapping: {
        type: Object,
        computed: "_makeDashboardMapping(_dashboards)",
      },
      // Which tab is selected (scalars, graph, images etc).
      mode: {
        type: String,
        computed: '_getModeFromIndex(modeIndex)',
        notify: true,
      },
      tabs: {
        type: Array,
        readOnly: true,
        value: TF.Globals.TABS,
        observer: '_modeChanged',
      },
      // If this is set to a string, TensorBoard will switch to "demo mode"
      // and attempt to load serialized json data from that directory. You can

@@ -267,7 +212,7 @@ allows the user to toggle between various dashboards.
        return true;
      },
      _getModeFromIndex: function(modeIndex) {
        var mode = this.tabs[modeIndex];
        var mode = this._dashboards[modeIndex].name;
        TF.URIStorage.setString(TF.URIStorage.TAB, mode);
        return mode;
      },

@@ -279,34 +224,10 @@ allows the user to toggle between various dashboards.
        return new TF.Backend.Backend(router);
      },
      _isReloadDisabled: function(mode) {
        return !this._debuggerDataEnabled && this._modeIsGraphs(mode);
      },
      _modeIsScalars: function(mode) {
        return mode === "scalars";
      },
      _modeIsImages: function(mode) {
        return mode === "images";
      },
      _modeIsAudio: function(mode) {
        return mode === "audio";
      },
      _modeIsGraphs: function(mode) {
        return mode === "graphs";
      },
      _modeIsEmbeddings: function(mode) {
        return mode === "embeddings";
      },
      _modeIsDistributions: function(mode) {
        return mode === "distributions";
      },
      _modeIsHistograms: function(mode) {
        return mode === "histograms";
      },
      _modeIsText: function(mode) {
        return mode === "text";
        return !this._debuggerDataEnabled && mode == 'graphs';
      },
      selectedDashboard: function() {
        var dashboard = this.$$("#" + this.mode);
        var dashboard = this._dashboardMapping[this.mode];
        if (dashboard == null) {
          throw new Error(`Unable to find dashboard for mode: ${this.mode}`);
        }

@@ -320,24 +241,67 @@ allows the user to toggle between various dashboards.
          this._getModeFromHash();
        }.bind(this));
      },
      _makeDashboardList: function(backend, router, debuggerDataEnabled) {
        if (!backend || !router) {
          // The dashboards require these entities. We are not ready to construct dashboards.
          return null;
        }

        return [
          new TF.Dashboard.TfScalarDashboard(backend, router),
          new TF.Dashboard.TfImageDashboard(backend),
          new TF.Dashboard.TfAudioDashboard(backend),
          new TF.Dashboard.TfGraphDashboard(backend, debuggerDataEnabled),
          new TF.Dashboard.TfDistributionDashboard(backend),
          new TF.Dashboard.TfHistogramDashboard(backend),
          new TF.Dashboard.VzProjectorDashboard('/data/plugin/projector'),
          new TF.Dashboard.TfTextDashboard(backend),
        ];
      },
      _makeDashboardMapping: function(dashboards) {
        if (!dashboards) {
          return null;
        }

        let mapping = {};
        dashboards.forEach(function(dashboard) {
          mapping[dashboard.name] = dashboard;
        });
        return mapping;
      },
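The list-to-mapping lookup that `_makeDashboardMapping` and `selectedDashboard` perform can be sketched with plain dicts standing in for the Polymer dashboard objects (illustrative only, not the TensorBoard code):

```python
# Stub dashboards stand in for the objects built by _makeDashboardList.
dashboards = [{'name': 'scalars'}, {'name': 'graphs'}]

# Mirror of _makeDashboardMapping: key each dashboard by its name.
mapping = {dashboard['name']: dashboard for dashboard in dashboards}

def selected_dashboard(mode):
    """Mirror of selectedDashboard: look the mode up, failing loudly."""
    if mode not in mapping:
        raise ValueError('Unable to find dashboard for mode: %s' % mode)
    return mapping[mode]
```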
      _getModeFromHash: function() {
        var tabName = TF.URIStorage.getString(TF.URIStorage.TAB);
        var modeIndex = this.tabs.indexOf(tabName);
        if (modeIndex == -1 && this.modeIndex == null) {
        var modeIndex;
        for (var i = 0; i < this._dashboards.length; i++) {
          if (this._dashboards[i].name == tabName) {
            modeIndex = i;
            break;
          }
        }

        if (modeIndex === undefined && this.modeIndex == null) {
          // Select the first tab as default.
          this.set('modeIndex', 0);
        }
        if (modeIndex != -1 && modeIndex != this.modeIndex) {
        if (modeIndex !== undefined && modeIndex != this.modeIndex) {
          this.set('modeIndex', modeIndex);
        }
      },
      _modeChanged: function(mode) {
        let currentDashboard = this.$.content.firstChild;
        if (currentDashboard) {
          this.$.content.removeChild(currentDashboard);
        }

        if (!mode || !this._dashboardMapping) {
          return;
        }

        // Append the new dashboard.
        const newDashboard = this.selectedDashboard();
        this.$.content.appendChild(newDashboard);
      },
      reload: function() {
        if (this._modeIsEmbeddings(this.mode)) {
          return;
        }
        if (!this._debuggerDataEnabled && this._modeIsGraphs(this.mode)) {
          return;
        }
        this.selectedDashboard().reload();
      },
      openSettings: function() {

@@ -82,8 +82,11 @@ tf-text-dashboard displays a dashboard that loads texts from a TensorFlow run.
  </style>
</template>
<script>
  Polymer({
  TF.Dashboard.TfTextDashboard = Polymer({
    is: "tf-text-dashboard",
    factoryImpl: function(backend) {
      this.backend = backend;
    },
    properties: {
      backend: Object,
      dataType: {

@@ -92,15 +95,15 @@ tf-text-dashboard displays a dashboard that loads texts from a TensorFlow run.
      },
    },
    behaviors: [
      TF.Dashboard.ReloadBehavior("tf-chart-scaffold"),
      TF.Backend.Behavior,
      TF.Dashboard.DashboardBehavior("text"),
      TF.Dashboard.ReloadBehavior("tf-chart-scaffold"),
      TF.Backend.Behavior,
    ],
    attached: function() {
      this.async(function() {
        this.fire("rendered");
      });
    },

  });
</script>
</dom-module>

@@ -16,6 +16,7 @@ limitations under the License.
-->

<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-dashboard-common/tf-dashboard.html">
<link rel="import" href="../tf-dashboard-common/tf-no-data-warning.html">
<link rel="import" href="vz-projector.html">

@@ -37,18 +38,33 @@ limitations under the License.
</template>
<script>
(function() {
  Polymer({
  TF.Dashboard.VzProjectorDashboard = Polymer({
    is: 'vz-projector-dashboard',
    factoryImpl: function(routePrefix) {
      this.routePrefix = routePrefix;
    },
    properties: {
      dataNotFound: Boolean,
      routePrefix: String
    },
    ready() {
      var self = this;
      d3.json(this.routePrefix + '/runs', function(err, runs) {
        self.dataNotFound = (runs.length === 0);
    behaviors: [
      TF.Dashboard.DashboardBehavior("embeddings"),
    ],
    observers: [
      "_routePrefixSet(routePrefix)",
    ],
    reload: function() {
      // Do not reload the embedding projector. Reloading could take a long time.
    },
    _routePrefixSet: function(routePrefix) {
      if (routePrefix === undefined) {
        // The route prefix has not been set yet.
        return;
      }
      d3.json(routePrefix + '/runs', (err, runs) => {
        this.set('dataNotFound', runs.length === 0);
      });
    }
    },
  });
})();
</script>

@@ -1,55 +0,0 @@
# Description:
#   TensorBoard plugin for interacting with tfdbg, the TensorFlow debugger

package(default_visibility = ["//tensorflow:internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "py_test")

## TensorFlow Debugger Plugin ##
py_library(
    name = "debugger_plugin",
    srcs = ["debugger_plugin.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework",
        "//tensorflow/python:platform",
        "//tensorflow/tensorboard/backend:http_util",
        "//tensorflow/tensorboard/backend/event_processing:event_accumulator",
        "//tensorflow/tensorboard/backend/event_processing:event_file_loader",
        "//tensorflow/tensorboard/plugins:base_plugin",
    ],
)

py_test(
    name = "debugger_plugin_test",
    size = "small",
    srcs = ["debugger_plugin_test.py"],
    main = "debugger_plugin_test.py",
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":debugger_plugin",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:util",
        "//tensorflow/tensorboard/backend:application",
        "//tensorflow/tensorboard/backend/event_processing:event_multiplexer",
        "//third_party/py/numpy",
        "@org_pocoo_werkzeug//:werkzeug",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        [
            "*",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

@@ -1,355 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The plugin for serving data from a TensorFlow debugger."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import glob
import json
import os
import re

from werkzeug import wrappers

from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.tensorboard.backend import http_util
from tensorflow.tensorboard.backend.event_processing import event_accumulator
from tensorflow.tensorboard.backend.event_processing import event_file_loader
from tensorflow.tensorboard.plugins import base_plugin

# The prefix of routes provided by this plugin.
_PLUGIN_PREFIX_ROUTE = 'debugger'

# HTTP routes.
_HEALTH_PILLS_ROUTE = '/health_pills'

# The POST key of HEALTH_PILLS_ROUTE for a JSON list of node names.
_NODE_NAMES_POST_KEY = 'node_names'

# The POST key of HEALTH_PILLS_ROUTE for the run to retrieve health pills for.
_RUN_POST_KEY = 'run'

# The default run to retrieve health pills for.
_DEFAULT_RUN = '.'

# The POST key of HEALTH_PILLS_ROUTE for the specific step to retrieve health
# pills for.
_STEP_POST_KEY = 'step'

# A glob pattern for files containing debugger-related events.
_DEBUGGER_EVENTS_GLOB_PATTERN = 'events.debugger*'


class DebuggerPlugin(base_plugin.TBPlugin):
  """TensorFlow Debugger plugin. Receives requests for debugger-related data.

  That data could include health pills, which unveil the status of tensor
  values.
  """

  plugin_name = _PLUGIN_PREFIX_ROUTE

  def get_plugin_apps(self, multiplexer, logdir):
    """Obtains a mapping between routes and handlers. Stores the logdir.

    Args:
      multiplexer: The EventMultiplexer that provides TB data.
      logdir: The logdir string - the directory of events files.

    Returns:
      A mapping between routes and handlers (functions that respond to
      requests).
    """
    self._event_multiplexer = multiplexer
    self._logdir = logdir
    return {
        _HEALTH_PILLS_ROUTE: self._serve_health_pills_handler,
    }

  def is_active(self):
    """Determines whether this plugin is active.

    This plugin is active if any health pills information is present for any
    run. This method must be called only after get_plugin_apps has been called.

    Returns:
      A boolean. Whether this plugin is active.
    """
    for run_name in self._event_multiplexer.Runs():
      if self._event_multiplexer.GetOpsWithHealthPills(run_name):
        return True

    return False

  @wrappers.Request.application
  def _serve_health_pills_handler(self, request):
    """A (wrapped) werkzeug handler for serving health pills.

    Accepts POST requests and responds with health pills. The request accepts
    several POST parameters:

      node_names: (required string) A JSON-ified list of node names for which
          the client would like to request health pills.
      run: (optional string) The run to retrieve health pills for. Defaults to
          '.'. This data is sent via POST (not GET) since URL length is limited.
      step: (optional integer): The session run step for which to
          retrieve health pills. If provided, the handler reads the health pills
          of that step from disk (which is slow) and produces a response with
          only health pills at that step. If not provided, the handler returns a
          response with health pills at all steps sampled by the event
          multiplexer (the fast path). The motivation here is that, sometimes,
          one desires to examine health pills at a specific step (to say find
          the first step that causes a model to blow up with NaNs).
          get_plugin_apps must be called before this slower feature is used
          because that method passes the logdir (directory path) to this plugin.

    This handler responds with a JSON-ified object mapping from node names to a
    list (of size 1) of health pill event objects, each of which has these
    properties.

    {
        'wall_time': float,
        'step': int,
        'node_name': string,
        'output_slot': int,
        # A list of 12 floats that summarizes the elements of the tensor.
        'value': float[],
    }

    Node names for which there are no health pills to be found are excluded from
    the mapping.

    Args:
      request: The request issued by the client for health pills.

    Returns:
      A werkzeug BaseResponse object.
    """
    if request.method != 'POST':
      logging.error(
          '%s requests are forbidden by the debugger plugin.', request.method)
      return wrappers.Response(status=405)

    if _NODE_NAMES_POST_KEY not in request.form:
      logging.error(
          'The %r POST key was not found in the request for health pills.',
          _NODE_NAMES_POST_KEY)
      return wrappers.Response(status=400)

    jsonified_node_names = request.form[_NODE_NAMES_POST_KEY]
    try:
      node_names = json.loads(jsonified_node_names)
    except Exception as e:  # pylint: disable=broad-except
      # Different JSON libs raise different exceptions, so we just do a
      # catch-all here. This problem is complicated by how Tensorboard might be
      # run in many different environments, as it is open-source.
      logging.error('Could not decode node name JSON string %r: %s',
                    jsonified_node_names, e)
      return wrappers.Response(status=400)

    if not isinstance(node_names, list):
      logging.error('%r is not a JSON list of node names:',
                    jsonified_node_names)
      return wrappers.Response(status=400)

    run = request.form.get(_RUN_POST_KEY, _DEFAULT_RUN)
    step_string = request.form.get(_STEP_POST_KEY, None)
    if step_string is None:
      # Use all steps sampled by the event multiplexer (Relatively fast).
      mapping = self._obtain_sampled_health_pills(run, node_names)
    else:
      # Read disk to obtain the health pills for that step (Relatively slow).
      # Make sure that the directory for the run exists.
      # Determine the directory of events file to read.
      events_directory = self._logdir
      if run != _DEFAULT_RUN:
        # Use the directory for the specific run.
        events_directory = os.path.join(events_directory, run)

      step = int(step_string)
      try:
        mapping = self._obtain_health_pills_at_step(
            events_directory, node_names, step)
      except IOError as error:
        logging.error(
            'Error retrieving health pills for step %d: %s', step, error)
        return wrappers.Response(status=404)

    # Convert event_accumulator.HealthPillEvents to JSON-able dicts.
    jsonable_mapping = {}
    for node_name, events in mapping.items():
      jsonable_mapping[node_name] = [e._asdict() for e in events]
    return http_util.Respond(request, jsonable_mapping, 'application/json')
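The final conversion loop of this handler turns `HealthPillEvent` namedtuples into JSON-able dicts via `_asdict`. A minimal sketch with a simplified stand-in for `event_accumulator.HealthPillEvent` (the real type carries the fields listed in the docstring above):

```python
import collections
import json

# Simplified stand-in for event_accumulator.HealthPillEvent.
HealthPillEvent = collections.namedtuple(
    'HealthPillEvent',
    ['wall_time', 'step', 'node_name', 'output_slot', 'value'])

mapping = {'MatMul': [HealthPillEvent(1.0, 7, 'MatMul', 0, [0.0] * 12)]}

# Mirror of the handler's conversion loop: namedtuples -> JSON-able dicts.
jsonable_mapping = {}
for node_name, events in mapping.items():
    jsonable_mapping[node_name] = [e._asdict() for e in events]

payload = json.dumps(jsonable_mapping)
```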

  def _obtain_sampled_health_pills(self, run, node_names):
    """Obtains the health pills for a run sampled by the event multiplexer.

    This is much faster than the alternative path of reading health pills from
    disk.

    Args:
      run: The run to fetch health pills for.
      node_names: A list of node names for which to retrieve health pills.

    Returns:
      A dictionary mapping from node name to a list of
      event_accumulator.HealthPillEvents.
    """
    mapping = {}
    for node_name in node_names:
      try:
        mapping[node_name] = self._event_multiplexer.HealthPills(run, node_name)
      except KeyError:
        logging.info('No health pills found for node %r.', node_name)
        continue

    return mapping

  def _obtain_health_pills_at_step(self, events_directory, node_names, step):
    """Reads disk to obtain the health pills for a run at a specific step.

    This could be much slower than the alternative path of just returning all
    health pills sampled by the event multiplexer. It could take tens of minutes
    to complete this call for large graphs for big step values (in the
    thousands).

    Args:
      events_directory: The directory containing events for the desired run.
      node_names: A list of node names for which to retrieve health pills.
      step: The step to obtain health pills for.

    Returns:
      A dictionary mapping from node name to a list of health pill objects (see
      docs for _serve_health_pills_handler for properties of those objects).

    Raises:
      IOError: If no files with health pill events could be found.
    """
    # Obtain all files with debugger-related events.
    pattern = os.path.join(events_directory, _DEBUGGER_EVENTS_GLOB_PATTERN)
    file_paths = glob.glob(pattern)

    if not file_paths:
      raise IOError(
          'No events files found that match the pattern %r.', pattern)

    # Sort by name (and thus by timestamp).
    file_paths.sort()

    mapping = collections.defaultdict(list)
    node_name_set = frozenset(node_names)

    for file_path in file_paths:
      should_stop = self._process_health_pill_event(
          node_name_set, mapping, step, file_path)
      if should_stop:
        break

    return mapping
|
||||
|
||||
def _process_health_pill_event(self, node_name_set, mapping, target_step,
|
||||
file_path):
|
||||
"""Creates health pills out of data in an event.
|
||||
|
||||
Creates health pills out of the event and adds them to the mapping.
|
||||
|
||||
Args:
|
||||
node_name_set: A set of node names that are relevant.
|
||||
mapping: The mapping from node name to event_accumulator.HealthPillEvents.
|
||||
This object may be destructively modified.
|
||||
target_step: The target step at which to obtain health pills.
|
||||
file_path: The path to the file with health pill events.
|
||||
|
||||
Returns:
|
||||
Whether we should stop reading events because future events are no longer
|
||||
relevant.
|
||||
"""
|
||||
events_loader = event_file_loader.EventFileLoader(file_path)
|
||||
for event in events_loader.Load():
|
||||
if not event.HasField('summary'):
|
||||
logging.warning('An event in a debugger events file lacks a summary.')
|
||||
continue
|
||||
|
||||
if event.step < target_step:
|
||||
# This event is not of the relevant step. We perform this check
|
||||
# first because the majority of events will be eliminated from
|
||||
# consideration by this check.
|
||||
continue
|
||||
|
||||
if event.step > target_step:
|
||||
# We have passed the relevant step. No need to read more events.
|
||||
return True
|
||||
|
||||
for value in event.summary.value:
|
||||
# Since we seek health pills for a specific step, this function
|
||||
# returns 1 health pill per node per step. The wall time is the
|
||||
# seconds since the epoch.
|
||||
health_pill = self._process_health_pill_value(
|
||||
node_name_set, event.wall_time, event.step, value)
|
||||
if not health_pill:
|
||||
continue
|
||||
mapping[health_pill.node_name].append(health_pill)
|
||||
|
||||
# Keep reading events.
|
||||
return False
|
||||
|
||||
def _process_health_pill_value(self, node_name_set, wall_time, step, value):
|
||||
"""Creates a dict containing various properties of a health pill.
|
||||
|
||||
Args:
|
||||
node_name_set: A set of node names that are relevant.
|
||||
wall_time: The wall time in seconds.
|
||||
step: The session run step of the event.
|
||||
value: The health pill value.
|
||||
|
||||
Returns:
|
||||
An event_accumulator.HealthPillEvent. Or None if one could not be created.
|
||||
"""
|
||||
if not value.HasField('tensor'):
|
||||
logging.warning(
|
||||
'An event in a debugger events file lacks a tensor value.')
|
||||
return None
|
||||
|
||||
if value.tag != event_accumulator.HEALTH_PILL_EVENT_TAG:
|
||||
logging.warning(
|
||||
('A debugger-related event lacks the %r tag. It instead has '
|
||||
'the %r tag.'), event_accumulator.HEALTH_PILL_EVENT_TAG, value.tag)
|
||||
return None
|
||||
|
||||
match = re.match(r'^(.*):(\d+):DebugNumericSummary$', value.node_name)
|
||||
if not match:
|
||||
logging.warning(
|
||||
('A event with a health pill has an invalid watch, (i.e., an '
|
||||
'unexpected debug op): %r'), value.node_name)
|
||||
return None
|
||||
|
||||
node_name = match.group(1)
|
||||
if node_name not in node_name_set:
|
||||
# This event is not relevant.
|
||||
return None
|
||||
|
||||
# Since we seek health pills for a specific step, this function
|
||||
# returns 1 health pill per node per step. The wall time is the
|
||||
# seconds since the epoch.
|
||||
return event_accumulator.HealthPillEvent(
|
||||
wall_time=wall_time,
|
||||
step=step,
|
||||
node_name=node_name,
|
||||
output_slot=int(match.group(2)),
|
||||
value=list(tensor_util.MakeNdarray(value.tensor)))
|
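The watch-key parsing above can be sketched as a standalone helper. This is a hypothetical function for illustration (`parse_watch_key` is not part of the plugin); it mirrors the regex used in `_process_health_pill_value`, which expects debug watch keys of the form `<node_name>:<output_slot>:DebugNumericSummary`:

```python
import re

# A tfdbg watch key has the form "<node_name>:<output_slot>:<debug_op>".
_WATCH_KEY_PATTERN = re.compile(r'^(.*):(\d+):DebugNumericSummary$')


def parse_watch_key(watch_key):
  """Splits a DebugNumericSummary watch key into (node_name, output_slot).

  Returns None when the key does not use the DebugNumericSummary debug op,
  matching how the plugin skips irrelevant values.
  """
  match = _WATCH_KEY_PATTERN.match(watch_key)
  if not match:
    return None
  return match.group(1), int(match.group(2))
```

The greedy `(.*)` keeps any colons or slashes inside the node name itself, so names like `layers/Matmul` survive the split intact.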
@ -1,300 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the TensorBoard debugger data plugin."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
import shutil

import numpy as np
from werkzeug import test as werkzeug_test
from werkzeug import wrappers

from tensorflow.core.framework import types_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import test
from tensorflow.python.util import compat
from tensorflow.tensorboard.backend import application
from tensorflow.tensorboard.backend.event_processing import event_multiplexer
from tensorflow.tensorboard.plugins.debugger import debugger_plugin


class DebuggerPluginTest(test.TestCase):

  def setUp(self):
    # Populate the log directory with debugger events for run '.'.
    self.log_dir = self.get_temp_dir()
    file_prefix = compat.as_bytes(os.path.join(self.log_dir, 'events.debugger'))
    writer = pywrap_tensorflow.EventsWriter(file_prefix)
    writer.WriteEvent(
        self._CreateEventWithDebugNumericSummary(
            op_name='layers/Matmul',
            output_slot=0,
            wall_time=42,
            step=2,
            list_of_values=[1, 2, 3]))
    writer.WriteEvent(
        self._CreateEventWithDebugNumericSummary(
            op_name='layers/Matmul',
            output_slot=1,
            wall_time=43,
            step=7,
            list_of_values=[4, 5, 6]))
    writer.WriteEvent(
        self._CreateEventWithDebugNumericSummary(
            op_name='logits/Add',
            output_slot=0,
            wall_time=1337,
            step=7,
            list_of_values=[7, 8, 9]))
    writer.WriteEvent(
        self._CreateEventWithDebugNumericSummary(
            op_name='logits/Add',
            output_slot=0,
            wall_time=1338,
            step=8,
            list_of_values=[10, 11, 12]))
    writer.Close()

    # Populate the log directory with debugger events for run 'run_foo'.
    run_foo_directory = os.path.join(self.log_dir, 'run_foo')
    os.mkdir(run_foo_directory)
    file_prefix = compat.as_bytes(
        os.path.join(run_foo_directory, 'events.debugger'))
    writer = pywrap_tensorflow.EventsWriter(file_prefix)
    writer.WriteEvent(
        self._CreateEventWithDebugNumericSummary(
            op_name='layers/Variable',
            output_slot=0,
            wall_time=4242,
            step=42,
            list_of_values=[13, 14, 15]))
    writer.Close()

    # Start a server that will receive requests and respond with health pills.
    self.multiplexer = event_multiplexer.EventMultiplexer({
        '.': self.log_dir,
        'run_foo': run_foo_directory,
    })
    self.plugin = debugger_plugin.DebuggerPlugin()
    wsgi_app = application.TensorBoardWSGIApp(
        self.log_dir, [self.plugin],
        self.multiplexer,
        reload_interval=0)
    self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse)

  def tearDown(self):
    # Remove the directory with debugger-related events files.
    shutil.rmtree(self.log_dir, ignore_errors=True)

  def _CreateEventWithDebugNumericSummary(
      self, op_name, output_slot, wall_time, step, list_of_values):
    """Creates an event with a health pill summary.

    Args:
      op_name: The name of the op to which a DebugNumericSummary was attached.
      output_slot: The numeric output slot for the tensor.
      wall_time: The numeric wall time of the event.
      step: The step of the event.
      list_of_values: A python list of values within the tensor.

    Returns:
      An event_pb2.Event with a health pill summary.
    """
    event = event_pb2.Event(step=step, wall_time=wall_time)
    value = event.summary.value.add(
        tag='__health_pill__',
        node_name='%s:%d:DebugNumericSummary' % (op_name, output_slot))
    value.tensor.tensor_shape.dim.add(size=len(list_of_values))
    value.tensor.dtype = types_pb2.DT_DOUBLE
    value.tensor.tensor_content = np.array(
        list_of_values, dtype=np.float64).tobytes()
    return event

  def _DeserializeResponse(self, byte_content):
    """Deserializes byte content that is a JSON encoding.

    Args:
      byte_content: The byte content of a JSON response.

    Returns:
      The deserialized python object decoded from JSON.
    """
    return json.loads(byte_content.decode('utf-8'))

  def testHealthPillsRouteProvided(self):
    """Tests that the plugin offers the route for requesting health pills."""
    apps = self.plugin.get_plugin_apps(self.multiplexer, self.log_dir)
    self.assertIn('/health_pills', apps)
    self.assertIsInstance(apps['/health_pills'], collections.Callable)

  def testHealthPillsPluginIsActive(self):
    self.plugin.get_plugin_apps(self.multiplexer, self.log_dir)

    # The multiplexer has sampled health pills.
    self.assertTrue(self.plugin.is_active())

  def testHealthPillsPluginIsInactive(self):
    self.plugin.get_plugin_apps(
        event_multiplexer.EventMultiplexer({}), self.log_dir)

    # The multiplexer lacks sampled health pills.
    self.assertFalse(self.plugin.is_active())

  def testRequestHealthPillsForRunFoo(self):
    """Tests that the plugin produces health pills for a specified run."""
    response = self.server.post(
        '/data/plugin/debugger/health_pills',
        data={
            'node_names': json.dumps(['layers/Variable', 'unavailable_node']),
            'run': 'run_foo',
        })
    self.assertEqual(200, response.status_code)
    self.assertDictEqual({
        'layers/Variable': [{
            'wall_time': 4242,
            'step': 42,
            'node_name': 'layers/Variable',
            'output_slot': 0,
            'value': [13, 14, 15],
        }],
    }, self._DeserializeResponse(response.get_data()))

  def testRequestHealthPillsForDefaultRun(self):
    """Tests that the plugin produces health pills for the default '.' run."""
    # Do not provide a 'run' parameter in POST data.
    response = self.server.post(
        '/data/plugin/debugger/health_pills',
        data={
            'node_names': json.dumps(['logits/Add', 'unavailable_node']),
        })
    self.assertEqual(200, response.status_code)
    # The health pills for 'layers/Matmul' should not be included since the
    # request excluded that node name.
    self.assertDictEqual({
        'logits/Add': [
            {
                'wall_time': 1337,
                'step': 7,
                'node_name': 'logits/Add',
                'output_slot': 0,
                'value': [7, 8, 9],
            },
            {
                'wall_time': 1338,
                'step': 8,
                'node_name': 'logits/Add',
                'output_slot': 0,
                'value': [10, 11, 12],
            },
        ],
    }, self._DeserializeResponse(response.get_data()))

  def testGetRequestsUnsupported(self):
    """Tests that GET requests are unsupported."""
    response = self.server.get('/data/plugin/debugger/health_pills')
    self.assertEqual(405, response.status_code)

  def testRequestsWithoutProperPostKeyUnsupported(self):
    """Tests that requests lacking the node_names POST key are unsupported."""
    response = self.server.post('/data/plugin/debugger/health_pills')
    self.assertEqual(400, response.status_code)

  def testRequestsWithBadJsonUnsupported(self):
    """Tests that requests with undecodable JSON are unsupported."""
    response = self.server.post(
        '/data/plugin/debugger/health_pills',
        data={
            'node_names': 'some obviously non JSON text',
        })
    self.assertEqual(400, response.status_code)

  def testRequestsWithNonListPostDataUnsupported(self):
    """Tests that requests whose payloads lack lists of ops are unsupported."""
    response = self.server.post(
        '/data/plugin/debugger/health_pills',
        data={
            'node_names': json.dumps({
                'this is a dict': 'and not a list.'
            }),
        })
    self.assertEqual(400, response.status_code)

  def testFetchHealthPillsForSpecificStep(self):
    """Tests that requesting health pills at a specific step works.

    This path may be slow in real life because it reads from disk.
    """
    # Request health pills for these nodes at step 7 specifically.
    response = self.server.post(
        '/data/plugin/debugger/health_pills',
        data={
            'node_names': json.dumps(['logits/Add', 'layers/Matmul']),
            'step': 7
        })
    self.assertEqual(200, response.status_code)
    # The response should only include health pills at step 7.
    self.assertDictEqual({
        'logits/Add': [
            {
                'wall_time': 1337,
                'step': 7,
                'node_name': 'logits/Add',
                'output_slot': 0,
                'value': [7, 8, 9],
            },
        ],
        'layers/Matmul': [
            {
                'wall_time': 43,
                'step': 7,
                'node_name': 'layers/Matmul',
                'output_slot': 1,
                'value': [4, 5, 6],
            },
        ],
    }, self._DeserializeResponse(response.get_data()))

  def testNoHealthPillsForSpecificStep(self):
    """Tests that an empty mapping is returned for no health pills at a step."""
    response = self.server.post(
        '/data/plugin/debugger/health_pills',
        data={
            'node_names': json.dumps(['some/clearly/non-existent/op']),
            'step': 7
        })
    self.assertEqual(200, response.status_code)
    self.assertDictEqual({}, self._DeserializeResponse(response.get_data()))

  def testNoHealthPillsForOutOfRangeStep(self):
    """Tests that an empty mapping is returned for an out-of-range step."""
    response = self.server.post(
        '/data/plugin/debugger/health_pills',
        data={
            'node_names': json.dumps(['logits/Add', 'layers/Matmul']),
            # This step is higher than that of any event written to disk.
            'step': 42424242
        })
    self.assertEqual(200, response.status_code)
    self.assertDictEqual({}, self._DeserializeResponse(response.get_data()))


if __name__ == '__main__':
  test.main()
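The round trip the test fixture relies on can be shown without TensorFlow: DT_DOUBLE values are packed into raw `tensor_content` bytes with `tobytes()`, and `tensor_util.MakeNdarray` later recovers them. A minimal sketch using plain numpy (the helper names here are illustrative, not part of the codebase):

```python
import numpy as np


def encode_values(list_of_values):
  # Mirrors the test fixture: pack doubles into raw tensor_content bytes.
  return np.array(list_of_values, dtype=np.float64).tobytes()


def decode_values(tensor_content):
  # Mirrors what tensor_util.MakeNdarray does for a rank-1 DT_DOUBLE tensor.
  return list(np.frombuffer(tensor_content, dtype=np.float64))
```

Because the dtype recorded on the proto (`types_pb2.DT_DOUBLE`) must agree with the dtype used for encoding, a mismatch would silently reinterpret the bytes rather than raise.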
@ -32,7 +32,8 @@ from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.tensorboard.backend import application
from tensorflow.tensorboard.backend.event_processing import event_file_inspector as efi

from tensorflow.tensorboard.plugins.projector import projector_plugin
from tensorflow.tensorboard.plugins.text import text_plugin

# TensorBoard flags

@ -88,8 +89,18 @@ flags.DEFINE_string(
FLAGS = flags.FLAGS


def create_tb_app():
  """Read the flags, and create a TensorBoard WSGI application."""
def create_tb_app(plugins):
  """Read the flags, and create a TensorBoard WSGI application.

  Args:
    plugins: A list of plugins for TensorBoard to initialize.

  Raises:
    ValueError: if a logdir is not specified.

  Returns:
    A new TensorBoard WSGI application.
  """
  if not FLAGS.logdir:
    raise ValueError('A logdir must be specified. Run `tensorboard --help` for '
                     'details and examples.')
@ -98,7 +109,8 @@ def create_tb_app():
  return application.standard_tensorboard_wsgi(
      logdir=logdir,
      purge_orphaned_data=FLAGS.purge_orphaned_data,
      reload_interval=FLAGS.reload_interval)
      reload_interval=FLAGS.reload_interval,
      plugins=plugins)


def make_simple_server(tb_app, host, port):
@ -184,7 +196,11 @@ def main(unused_argv=None):
    efi.inspect(FLAGS.logdir, event_file, FLAGS.tag)
    return 0
  else:
    tb = create_tb_app()
    plugins = [
        projector_plugin.ProjectorPlugin(),
        text_plugin.TextPlugin(),
    ]
    tb = create_tb_app(plugins)
    run_simple_server(tb)


if __name__ == '__main__':
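The wiring above, where `main` builds a plugin list and hands it to `create_tb_app`, can be modeled abstractly: each plugin exposes routes, and the application mounts them under `/data/plugin/<name>`. This is a simplified sketch, not the real `TensorBoardWSGIApp` code, and `register_plugin_routes` is a hypothetical helper:

```python
def register_plugin_routes(plugins):
  """Maps '/data/plugin/<name><route>' paths to handlers.

  Simplified model of how TensorBoard mounts plugin routes; `plugins` is a
  list of (plugin_name, {route: handler}) pairs, standing in for the
  name/get_plugin_apps pair each real plugin provides.
  """
  routes = {}
  for name, plugin_apps in plugins:
    for route, handler in plugin_apps.items():
      routes['/data/plugin/%s%s' % (name, route)] = handler
  return routes
```

This mirrors why the debugger test above POSTs to `/data/plugin/debugger/health_pills`: the plugin declares `/health_pills`, and the application prefixes it with the plugin's own name.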
24
tensorflow/tools/api/golden/BUILD
Normal file
@ -0,0 +1,24 @@
# TensorFlow API backwards compatibility test goldens.

package(
    default_visibility = ["//tensorflow/tools/api:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

filegroup(
    name = "api_golden",
    srcs = glob(["*.pbtxt"]),
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
@ -0,0 +1,24 @@
path: "tensorflow.AggregationMethod"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.gradients_impl.AggregationMethod\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "ADD_N"
    mtype: "<type \'int\'>"
  }
  member {
    name: "DEFAULT"
    mtype: "<type \'int\'>"
  }
  member {
    name: "EXPERIMENTAL_ACCUMULATE_N"
    mtype: "<type \'int\'>"
  }
  member {
    name: "EXPERIMENTAL_TREE"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "__init__"
  }
}
@ -0,0 +1,108 @@
path: "tensorflow.AttrValue.ListValue"
tf_class {
  is_instance: "<class \'tensorflow.core.framework.attr_value_pb2.ListValue\'>"
  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
  member {
    name: "B_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "DESCRIPTOR"
    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
  }
  member {
    name: "Extensions"
    mtype: "<type \'getset_descriptor\'>"
  }
  member {
    name: "FUNC_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "F_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "I_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "SHAPE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "S_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "TENSOR_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "TYPE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "ByteSize"
  }
  member_method {
    name: "Clear"
  }
  member_method {
    name: "ClearExtension"
  }
  member_method {
    name: "ClearField"
  }
  member_method {
    name: "CopyFrom"
  }
  member_method {
    name: "DiscardUnknownFields"
  }
  member_method {
    name: "FindInitializationErrors"
  }
  member_method {
    name: "FromString"
  }
  member_method {
    name: "HasExtension"
  }
  member_method {
    name: "HasField"
  }
  member_method {
    name: "IsInitialized"
  }
  member_method {
    name: "ListFields"
  }
  member_method {
    name: "MergeFrom"
  }
  member_method {
    name: "MergeFromString"
  }
  member_method {
    name: "ParseFromString"
  }
  member_method {
    name: "RegisterExtension"
  }
  member_method {
    name: "SerializePartialToString"
  }
  member_method {
    name: "SerializeToString"
  }
  member_method {
    name: "SetInParent"
  }
  member_method {
    name: "WhichOneof"
  }
  member_method {
    name: "__init__"
  }
}
120
tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt
Normal file
@ -0,0 +1,120 @@
path: "tensorflow.AttrValue"
tf_class {
  is_instance: "<class \'tensorflow.core.framework.attr_value_pb2.AttrValue\'>"
  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
  member {
    name: "B_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "DESCRIPTOR"
    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
  }
  member {
    name: "Extensions"
    mtype: "<type \'getset_descriptor\'>"
  }
  member {
    name: "FUNC_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "F_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "I_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "LIST_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "ListValue"
    mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
  }
  member {
    name: "PLACEHOLDER_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "SHAPE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "S_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "TENSOR_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "TYPE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "ByteSize"
  }
  member_method {
    name: "Clear"
  }
  member_method {
    name: "ClearExtension"
  }
  member_method {
    name: "ClearField"
  }
  member_method {
    name: "CopyFrom"
  }
  member_method {
    name: "DiscardUnknownFields"
  }
  member_method {
    name: "FindInitializationErrors"
  }
  member_method {
    name: "FromString"
  }
  member_method {
    name: "HasExtension"
  }
  member_method {
    name: "HasField"
  }
  member_method {
    name: "IsInitialized"
  }
  member_method {
    name: "ListFields"
  }
  member_method {
    name: "MergeFrom"
  }
  member_method {
    name: "MergeFromString"
  }
  member_method {
    name: "ParseFromString"
  }
  member_method {
    name: "RegisterExtension"
  }
  member_method {
    name: "SerializePartialToString"
  }
  member_method {
    name: "SerializeToString"
  }
  member_method {
    name: "SetInParent"
  }
  member_method {
    name: "WhichOneof"
  }
  member_method {
    name: "__init__"
  }
}
@ -0,0 +1,84 @@
path: "tensorflow.AutoParallelOptions"
tf_class {
  is_instance: "<class \'tensorflow.core.protobuf.rewriter_config_pb2.AutoParallelOptions\'>"
  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
  member {
    name: "DESCRIPTOR"
    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
  }
  member {
    name: "ENABLE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "Extensions"
    mtype: "<type \'getset_descriptor\'>"
  }
  member {
    name: "NUM_REPLICAS_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "ByteSize"
  }
  member_method {
    name: "Clear"
  }
  member_method {
    name: "ClearExtension"
  }
  member_method {
    name: "ClearField"
  }
  member_method {
    name: "CopyFrom"
  }
  member_method {
    name: "DiscardUnknownFields"
  }
  member_method {
    name: "FindInitializationErrors"
  }
  member_method {
    name: "FromString"
  }
  member_method {
    name: "HasExtension"
  }
  member_method {
    name: "HasField"
  }
  member_method {
    name: "IsInitialized"
  }
  member_method {
    name: "ListFields"
  }
  member_method {
    name: "MergeFrom"
  }
  member_method {
    name: "MergeFromString"
  }
  member_method {
    name: "ParseFromString"
  }
  member_method {
    name: "RegisterExtension"
  }
  member_method {
    name: "SerializePartialToString"
  }
  member_method {
    name: "SerializeToString"
  }
  member_method {
    name: "SetInParent"
  }
  member_method {
    name: "WhichOneof"
  }
  member_method {
    name: "__init__"
  }
}
@ -0,0 +1,29 @@
path: "tensorflow.ConditionalAccumulatorBase"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.data_flow_ops.ConditionalAccumulatorBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "accumulator_ref"
    mtype: "<type \'property\'>"
  }
  member {
    name: "dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'dtype\', \'shape\', \'accumulator_ref\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "num_accumulated"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "set_global_step"
    argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
}
@ -0,0 +1,38 @@
path: "tensorflow.ConditionalAccumulator"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.data_flow_ops.ConditionalAccumulator\'>"
  is_instance: "<class \'tensorflow.python.ops.data_flow_ops.ConditionalAccumulatorBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "accumulator_ref"
    mtype: "<type \'property\'>"
  }
  member {
    name: "dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
  }
  member_method {
    name: "apply_grad"
    argspec: "args=[\'self\', \'grad\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
  }
  member_method {
    name: "num_accumulated"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "set_global_step"
    argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take_grad"
    argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
}
@ -0,0 +1,84 @@
path: "tensorflow.ConfigProto.DeviceCountEntry"
tf_class {
  is_instance: "<class \'tensorflow.core.protobuf.config_pb2.DeviceCountEntry\'>"
  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
  member {
    name: "DESCRIPTOR"
    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
  }
  member {
    name: "Extensions"
    mtype: "<type \'getset_descriptor\'>"
  }
  member {
    name: "KEY_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "VALUE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "ByteSize"
  }
  member_method {
    name: "Clear"
  }
  member_method {
    name: "ClearExtension"
  }
  member_method {
    name: "ClearField"
  }
  member_method {
    name: "CopyFrom"
  }
  member_method {
    name: "DiscardUnknownFields"
  }
  member_method {
    name: "FindInitializationErrors"
  }
  member_method {
    name: "FromString"
  }
  member_method {
    name: "HasExtension"
  }
  member_method {
    name: "HasField"
  }
  member_method {
    name: "IsInitialized"
  }
  member_method {
    name: "ListFields"
  }
  member_method {
    name: "MergeFrom"
  }
  member_method {
    name: "MergeFromString"
  }
  member_method {
    name: "ParseFromString"
  }
  member_method {
    name: "RegisterExtension"
  }
  member_method {
    name: "SerializePartialToString"
  }
  member_method {
    name: "SerializeToString"
  }
  member_method {
    name: "SetInParent"
  }
  member_method {
    name: "WhichOneof"
  }
  member_method {
    name: "__init__"
  }
}
132
tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
Normal file
@ -0,0 +1,132 @@
|
||||
path: "tensorflow.ConfigProto"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.core.protobuf.config_pb2.ConfigProto\'>"
|
||||
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
|
||||
member {
|
||||
name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "DESCRIPTOR"
|
||||
mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEVICE_COUNT_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEVICE_FILTERS_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "DeviceCountEntry"
|
||||
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
|
||||
}
|
||||
member {
|
||||
name: "Extensions"
|
||||
mtype: "<type \'getset_descriptor\'>"
|
||||
}
|
||||
member {
|
||||
name: "GPU_OPTIONS_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "GRAPH_OPTIONS_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "INTER_OP_PARALLELISM_THREADS_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "INTRA_OP_PARALLELISM_THREADS_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "LOG_DEVICE_PLACEMENT_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "OPERATION_TIMEOUT_IN_MS_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "PLACEMENT_PERIOD_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "RPC_OPTIONS_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "SESSION_INTER_OP_THREAD_POOL_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "USE_PER_SESSION_THREADS_FIELD_NUMBER"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "ByteSize"
|
||||
}
|
||||
member_method {
|
||||
name: "Clear"
|
||||
}
|
||||
member_method {
|
||||
name: "ClearExtension"
|
||||
}
|
||||
member_method {
|
||||
name: "ClearField"
|
||||
}
|
||||
member_method {
|
||||
name: "CopyFrom"
|
||||
}
|
||||
member_method {
|
||||
name: "DiscardUnknownFields"
|
||||
}
|
||||
member_method {
|
||||
name: "FindInitializationErrors"
|
||||
}
|
||||
member_method {
|
||||
name: "FromString"
|
||||
}
|
||||
member_method {
|
||||
name: "HasExtension"
|
||||
}
|
||||
member_method {
|
||||
name: "HasField"
|
||||
}
|
||||
member_method {
|
||||
name: "IsInitialized"
|
||||
}
|
||||
member_method {
|
||||
name: "ListFields"
|
||||
}
|
||||
member_method {
|
||||
name: "MergeFrom"
|
||||
}
|
||||
member_method {
|
||||
name: "MergeFromString"
|
||||
}
|
||||
member_method {
|
||||
name: "ParseFromString"
|
||||
}
|
||||
member_method {
|
||||
name: "RegisterExtension"
|
||||
}
|
||||
member_method {
|
||||
name: "SerializePartialToString"
|
||||
}
|
||||
member_method {
|
||||
name: "SerializeToString"
|
||||
}
|
||||
member_method {
|
||||
name: "SetInParent"
|
||||
}
|
||||
member_method {
|
||||
name: "WhichOneof"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
}
|
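Golden files like the one above pin the public API surface of a class so that accidental additions or removals fail a test. A minimal pure-Python sketch of that idea follows; `ConfigLike`, `GOLDEN_METHODS`, and the helper names are illustrative stand-ins, not TensorFlow's actual api-compatibility test harness.

```python
# Sketch of a "golden API surface" check, in the spirit of the
# tensorflow/tools/api/golden files. All names here are illustrative.

class ConfigLike(object):
    """Hypothetical class whose public surface we want to pin down."""
    def ByteSize(self): return 0
    def Clear(self): pass
    def SerializeToString(self): return b""

# The "golden" record of the expected public members.
GOLDEN_METHODS = ["ByteSize", "Clear", "SerializeToString"]

def public_surface(cls):
    # Mirror what the golden files record: public member names only.
    return sorted(n for n in dir(cls) if not n.startswith("_"))

def check_against_golden(cls, golden):
    # Report names missing from the class and names not in the golden list.
    actual = public_surface(cls)
    missing = sorted(set(golden) - set(actual))
    extra = sorted(set(actual) - set(golden))
    return missing, extra

print(check_against_golden(ConfigLike, GOLDEN_METHODS))  # -> ([], [])
```

Any API change then shows up as a non-empty `missing` or `extra` list, prompting a deliberate update of the golden file.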
tensorflow/tools/api/golden/tensorflow.-d-type.pbtxt (new file)
@@ -0,0 +1,77 @@
path: "tensorflow.DType"
tf_class {
  is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "as_datatype_enum"
    mtype: "<type \'property\'>"
  }
  member {
    name: "as_numpy_dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "base_dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "is_bool"
    mtype: "<type \'property\'>"
  }
  member {
    name: "is_complex"
    mtype: "<type \'property\'>"
  }
  member {
    name: "is_floating"
    mtype: "<type \'property\'>"
  }
  member {
    name: "is_integer"
    mtype: "<type \'property\'>"
  }
  member {
    name: "is_numpy_compatible"
    mtype: "<type \'property\'>"
  }
  member {
    name: "is_quantized"
    mtype: "<type \'property\'>"
  }
  member {
    name: "is_unsigned"
    mtype: "<type \'property\'>"
  }
  member {
    name: "limits"
    mtype: "<type \'property\'>"
  }
  member {
    name: "max"
    mtype: "<type \'property\'>"
  }
  member {
    name: "min"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "real_dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "size"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "is_compatible_with"
    argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
  }
}
tensorflow/tools/api/golden/tensorflow.-device-spec.pbtxt (new file)
@@ -0,0 +1,37 @@
path: "tensorflow.DeviceSpec"
tf_class {
  is_instance: "<class \'tensorflow.python.framework.device.DeviceSpec\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "job"
    mtype: "<type \'property\'>"
  }
  member {
    name: "replica"
    mtype: "<type \'property\'>"
  }
  member {
    name: "task"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'job\', \'replica\', \'task\', \'device_type\', \'device_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "from_string"
    argspec: "args=[\'spec\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "merge_from"
    argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "parse_from_string"
    argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "to_string"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
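`tf.DeviceSpec` converts to and from device strings of the form `/job:<job>/replica:<r>/task:<t>/device:<TYPE>:<index>`. The helper below is a plain-Python sketch of that `to_string`-style formatting, not TensorFlow's implementation; the function name is illustrative.

```python
# Illustrative formatter for the device-string form that tf.DeviceSpec
# works with. Not TensorFlow's implementation.

def device_spec_to_string(job=None, replica=None, task=None,
                          device_type=None, device_index=None):
    parts = []
    if job is not None:
        parts.append("/job:%s" % job)
    if replica is not None:
        parts.append("/replica:%d" % replica)
    if task is not None:
        parts.append("/task:%d" % task)
    if device_type is not None:
        # An unspecified index matches any device of this type.
        index = "*" if device_index is None else str(device_index)
        parts.append("/device:%s:%s" % (device_type.upper(), index))
    return "".join(parts)

print(device_spec_to_string(job="ps", replica=0, task=1,
                            device_type="GPU", device_index=0))
# -> /job:ps/replica:0/task:1/device:GPU:0
```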
tensorflow/tools/api/golden/tensorflow.-dimension.pbtxt (new file)
@@ -0,0 +1,25 @@
path: "tensorflow.Dimension"
tf_class {
  is_instance: "<class \'tensorflow.python.framework.tensor_shape.Dimension\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "value"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "assert_is_compatible_with"
    argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "is_compatible_with"
    argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "merge_with"
    argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
  }
}
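The `is_compatible_with`/`merge_with` pair on `tensorflow.Dimension` encodes shape-inference rules where `None` stands for an unknown size. A minimal sketch of those rules, using plain ints and `None` rather than the Dimension class itself:

```python
# Sketch of the compatibility/merge semantics behind tensorflow.Dimension.
# None represents an unknown dimension size. Illustrative only.

def dim_is_compatible(a, b):
    # Compatible if either size is unknown, or both are known and equal.
    return a is None or b is None or a == b

def dim_merge(a, b):
    # Merging keeps the more specific (known) value; unequal known sizes fail.
    if not dim_is_compatible(a, b):
        raise ValueError("Dimensions %r and %r are not compatible" % (a, b))
    return b if a is None else a

print(dim_merge(None, 5))       # -> 5
print(dim_is_compatible(3, 4))  # -> False
```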
tensorflow/tools/api/golden/tensorflow.-event.pbtxt (new file)
@@ -0,0 +1,112 @@
path: "tensorflow.Event"
tf_class {
  is_instance: "<class \'tensorflow.core.util.event_pb2.Event\'>"
  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
  member {
    name: "DESCRIPTOR"
    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
  }
  member {
    name: "Extensions"
    mtype: "<type \'getset_descriptor\'>"
  }
  member {
    name: "FILE_VERSION_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "GRAPH_DEF_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "LOG_MESSAGE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "META_GRAPH_DEF_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "SESSION_LOG_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "STEP_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "SUMMARY_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "TAGGED_RUN_METADATA_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "WALL_TIME_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "ByteSize"
  }
  member_method {
    name: "Clear"
  }
  member_method {
    name: "ClearExtension"
  }
  member_method {
    name: "ClearField"
  }
  member_method {
    name: "CopyFrom"
  }
  member_method {
    name: "DiscardUnknownFields"
  }
  member_method {
    name: "FindInitializationErrors"
  }
  member_method {
    name: "FromString"
  }
  member_method {
    name: "HasExtension"
  }
  member_method {
    name: "HasField"
  }
  member_method {
    name: "IsInitialized"
  }
  member_method {
    name: "ListFields"
  }
  member_method {
    name: "MergeFrom"
  }
  member_method {
    name: "MergeFromString"
  }
  member_method {
    name: "ParseFromString"
  }
  member_method {
    name: "RegisterExtension"
  }
  member_method {
    name: "SerializePartialToString"
  }
  member_method {
    name: "SerializeToString"
  }
  member_method {
    name: "SetInParent"
  }
  member_method {
    name: "WhichOneof"
  }
  member_method {
    name: "__init__"
  }
}
tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt (new file)
@@ -0,0 +1,62 @@
path: "tensorflow.FIFOQueue"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.data_flow_ops.FIFOQueue\'>"
  is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "dtypes"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "names"
    mtype: "<type \'property\'>"
  }
  member {
    name: "queue_ref"
    mtype: "<type \'property\'>"
  }
  member {
    name: "shapes"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'capacity\', \'dtypes\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'fifo_queue\'], "
  }
  member_method {
    name: "close"
    argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "dequeue"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "dequeue_many"
    argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "dequeue_up_to"
    argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "enqueue"
    argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "enqueue_many"
    argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "from_list"
    argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "size"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
}
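The FIFOQueue surface above is a graph-based, blocking queue op. To illustrate just the enqueue/dequeue contract those methods name, here is an eager, non-blocking plain-Python stand-in; `SimpleFIFOQueue` is a hypothetical class, not TensorFlow's implementation.

```python
from collections import deque

# Plain-Python stand-in for the FIFOQueue method surface. The real op is
# graph-based and blocks when full/empty; this sketch simply raises or
# pops immediately, purely to illustrate FIFO ordering.

class SimpleFIFOQueue(object):
    def __init__(self, capacity):
        self._capacity = capacity
        self._items = deque()

    def enqueue(self, val):
        if len(self._items) >= self._capacity:
            raise RuntimeError("queue is full")
        self._items.append(val)

    def enqueue_many(self, vals):
        for v in vals:
            self.enqueue(v)

    def dequeue(self):
        return self._items.popleft()

    def dequeue_many(self, n):
        return [self._items.popleft() for _ in range(n)]

    def size(self):
        return len(self._items)

q = SimpleFIFOQueue(capacity=3)
q.enqueue_many([1, 2, 3])
print(q.dequeue())        # -> 1
print(q.dequeue_many(2))  # -> [2, 3]
```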
@@ -0,0 +1,27 @@
path: "tensorflow.FixedLenFeature"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
  is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
  is_instance: "<type \'tuple\'>"
  member {
    name: "default_value"
    mtype: "<type \'property\'>"
  }
  member {
    name: "dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "shape"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
  }
  member_method {
    name: "count"
  }
  member_method {
    name: "index"
  }
}
@@ -0,0 +1,31 @@
path: "tensorflow.FixedLenSequenceFeature"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
  is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
  is_instance: "<type \'tuple\'>"
  member {
    name: "allow_missing"
    mtype: "<type \'property\'>"
  }
  member {
    name: "default_value"
    mtype: "<type \'property\'>"
  }
  member {
    name: "dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "shape"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
  }
  member_method {
    name: "count"
  }
  member_method {
    name: "index"
  }
}
@@ -0,0 +1,46 @@
path: "tensorflow.FixedLengthRecordReader"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.io_ops.FixedLengthRecordReader\'>"
  is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "reader_ref"
    mtype: "<type \'property\'>"
  }
  member {
    name: "supports_serialize"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "num_records_produced"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "num_work_units_completed"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "read"
    argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "read_up_to"
    argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "reset"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "restore_state"
    argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "serialize_state"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
}
tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt (new file)
@@ -0,0 +1,104 @@
path: "tensorflow.GPUOptions"
tf_class {
  is_instance: "<class \'tensorflow.core.protobuf.config_pb2.GPUOptions\'>"
  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
  member {
    name: "ALLOCATOR_TYPE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "ALLOW_GROWTH_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "DEFERRED_DELETION_BYTES_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "DESCRIPTOR"
    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
  }
  member {
    name: "Extensions"
    mtype: "<type \'getset_descriptor\'>"
  }
  member {
    name: "PER_PROCESS_GPU_MEMORY_FRACTION_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "POLLING_ACTIVE_DELAY_USECS_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "POLLING_INACTIVE_DELAY_MSECS_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "VISIBLE_DEVICE_LIST_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "ByteSize"
  }
  member_method {
    name: "Clear"
  }
  member_method {
    name: "ClearExtension"
  }
  member_method {
    name: "ClearField"
  }
  member_method {
    name: "CopyFrom"
  }
  member_method {
    name: "DiscardUnknownFields"
  }
  member_method {
    name: "FindInitializationErrors"
  }
  member_method {
    name: "FromString"
  }
  member_method {
    name: "HasExtension"
  }
  member_method {
    name: "HasField"
  }
  member_method {
    name: "IsInitialized"
  }
  member_method {
    name: "ListFields"
  }
  member_method {
    name: "MergeFrom"
  }
  member_method {
    name: "MergeFromString"
  }
  member_method {
    name: "ParseFromString"
  }
  member_method {
    name: "RegisterExtension"
  }
  member_method {
    name: "SerializePartialToString"
  }
  member_method {
    name: "SerializeToString"
  }
  member_method {
    name: "SetInParent"
  }
  member_method {
    name: "WhichOneof"
  }
  member_method {
    name: "__init__"
  }
}
tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt (new file)
@@ -0,0 +1,92 @@
path: "tensorflow.GraphDef"
tf_class {
  is_instance: "<class \'tensorflow.core.framework.graph_pb2.GraphDef\'>"
  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
  member {
    name: "DESCRIPTOR"
    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
  }
  member {
    name: "Extensions"
    mtype: "<type \'getset_descriptor\'>"
  }
  member {
    name: "LIBRARY_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "NODE_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "VERSIONS_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "VERSION_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "ByteSize"
  }
  member_method {
    name: "Clear"
  }
  member_method {
    name: "ClearExtension"
  }
  member_method {
    name: "ClearField"
  }
  member_method {
    name: "CopyFrom"
  }
  member_method {
    name: "DiscardUnknownFields"
  }
  member_method {
    name: "FindInitializationErrors"
  }
  member_method {
    name: "FromString"
  }
  member_method {
    name: "HasExtension"
  }
  member_method {
    name: "HasField"
  }
  member_method {
    name: "IsInitialized"
  }
  member_method {
    name: "ListFields"
  }
  member_method {
    name: "MergeFrom"
  }
  member_method {
    name: "MergeFromString"
  }
  member_method {
    name: "ParseFromString"
  }
  member_method {
    name: "RegisterExtension"
  }
  member_method {
    name: "SerializePartialToString"
  }
  member_method {
    name: "SerializeToString"
  }
  member_method {
    name: "SetInParent"
  }
  member_method {
    name: "WhichOneof"
  }
  member_method {
    name: "__init__"
  }
}
tensorflow/tools/api/golden/tensorflow.-graph-keys.pbtxt (new file)
@@ -0,0 +1,136 @@
path: "tensorflow.GraphKeys"
tf_class {
  is_instance: "<class \'tensorflow.python.framework.ops.GraphKeys\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "ACTIVATIONS"
    mtype: "<type \'str\'>"
  }
  member {
    name: "ASSET_FILEPATHS"
    mtype: "<type \'str\'>"
  }
  member {
    name: "BIASES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "CONCATENATED_VARIABLES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "COND_CONTEXT"
    mtype: "<type \'str\'>"
  }
  member {
    name: "EVAL_STEP"
    mtype: "<type \'str\'>"
  }
  member {
    name: "GLOBAL_STEP"
    mtype: "<type \'str\'>"
  }
  member {
    name: "GLOBAL_VARIABLES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "INIT_OP"
    mtype: "<type \'str\'>"
  }
  member {
    name: "LOCAL_INIT_OP"
    mtype: "<type \'str\'>"
  }
  member {
    name: "LOCAL_RESOURCES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "LOCAL_VARIABLES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "LOSSES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "MODEL_VARIABLES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "MOVING_AVERAGE_VARIABLES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "QUEUE_RUNNERS"
    mtype: "<type \'str\'>"
  }
  member {
    name: "READY_FOR_LOCAL_INIT_OP"
    mtype: "<type \'str\'>"
  }
  member {
    name: "READY_OP"
    mtype: "<type \'str\'>"
  }
  member {
    name: "REGULARIZATION_LOSSES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "RESOURCES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "SAVEABLE_OBJECTS"
    mtype: "<type \'str\'>"
  }
  member {
    name: "SAVERS"
    mtype: "<type \'str\'>"
  }
  member {
    name: "SUMMARIES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "SUMMARY_OP"
    mtype: "<type \'str\'>"
  }
  member {
    name: "TABLE_INITIALIZERS"
    mtype: "<type \'str\'>"
  }
  member {
    name: "TRAINABLE_RESOURCE_VARIABLES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "TRAINABLE_VARIABLES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "TRAIN_OP"
    mtype: "<type \'str\'>"
  }
  member {
    name: "UPDATE_OPS"
    mtype: "<type \'str\'>"
  }
  member {
    name: "VARIABLES"
    mtype: "<type \'str\'>"
  }
  member {
    name: "WEIGHTS"
    mtype: "<type \'str\'>"
  }
  member {
    name: "WHILE_CONTEXT"
    mtype: "<type \'str\'>"
  }
  member_method {
    name: "__init__"
  }
}
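The `GraphKeys` constants above are just string keys into a graph's named collections. A minimal sketch of that registry pattern follows; the `Collections` class is an illustrative stand-in, not TensorFlow's `Graph`, though the two key strings shown match the convention of constants naming plain strings.

```python
# Sketch of the named-collections registry that GraphKeys strings index
# into. Illustrative stand-in, not TensorFlow's Graph class.

class Collections(object):
    # Constants are plain strings, as in tf.GraphKeys.
    GLOBAL_VARIABLES = "variables"
    TRAINABLE_VARIABLES = "trainable_variables"

    def __init__(self):
        self._collections = {}

    def add_to_collection(self, key, value):
        self._collections.setdefault(key, []).append(value)

    def get_collection(self, key):
        # Missing keys yield an empty list rather than an error.
        return list(self._collections.get(key, []))

g = Collections()
g.add_to_collection(Collections.TRAINABLE_VARIABLES, "w")
g.add_to_collection(Collections.TRAINABLE_VARIABLES, "b")
print(g.get_collection(Collections.TRAINABLE_VARIABLES))  # -> ['w', 'b']
```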
tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt (new file)
@@ -0,0 +1,112 @@
path: "tensorflow.GraphOptions"
tf_class {
  is_instance: "<class \'tensorflow.core.protobuf.config_pb2.GraphOptions\'>"
  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
  member {
    name: "BUILD_COST_MODEL_AFTER_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "BUILD_COST_MODEL_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "DESCRIPTOR"
    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
  }
  member {
    name: "ENABLE_BFLOAT16_SENDRECV_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "ENABLE_RECV_SCHEDULING_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "Extensions"
    mtype: "<type \'getset_descriptor\'>"
  }
  member {
    name: "INFER_SHAPES_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "OPTIMIZER_OPTIONS_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "PLACE_PRUNED_GRAPH_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "REWRITE_OPTIONS_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member {
    name: "TIMELINE_STEP_FIELD_NUMBER"
    mtype: "<type \'int\'>"
  }
  member_method {
    name: "ByteSize"
  }
  member_method {
    name: "Clear"
  }
  member_method {
    name: "ClearExtension"
  }
  member_method {
    name: "ClearField"
  }
  member_method {
    name: "CopyFrom"
  }
  member_method {
    name: "DiscardUnknownFields"
  }
  member_method {
    name: "FindInitializationErrors"
  }
  member_method {
    name: "FromString"
  }
  member_method {
    name: "HasExtension"
  }
  member_method {
    name: "HasField"
  }
  member_method {
    name: "IsInitialized"
  }
  member_method {
    name: "ListFields"
  }
  member_method {
    name: "MergeFrom"
  }
  member_method {
    name: "MergeFromString"
  }
  member_method {
    name: "ParseFromString"
  }
  member_method {
    name: "RegisterExtension"
  }
  member_method {
    name: "SerializePartialToString"
  }
  member_method {
    name: "SerializeToString"
  }
  member_method {
    name: "SetInParent"
  }
  member_method {
    name: "WhichOneof"
  }
  member_method {
    name: "__init__"
  }
}
Some files were not shown because too many files have changed in this diff.