Branch 152550050 (#9059)

* Improve py_func error handling.

Automatically translate some python errors into corresponding TF errors at runtime.
Change: 152156821

* Update interaction with libpng so that we use the public API instead of
knowledge of the internal libpng data structures.
Change: 152167754

* TensorBoard plugins now contain their own name/route prefix.
Change: 152167807

* Passes trainable flag to separable_conv2d biases.
Change: 152170239

* Saving resource variables with a caching device.
Change: 152171539

* Drop loss from estimator_spec.eval_metric_ops, as required by core Estimator.
Change: 152179924

* sample_stats.percentile DOCFIX.
Change: 152182295

* Added a memory optimizer to grappler.
Change: 152184170

* Change default behavior of the tf runs selector:

- If there are fewer than 41 runs, enable them all by default
- If there are 41 runs or more, disable them all by default

This is in response to user complaints that having it enable only the first ten runs by default was confusing, because it was not obvious to users that some runs had been disabled.
However, it still solves the initial user complaint that having very many runs simultaneously enabled would lag the UI.

I also changed the "toggle all runs" button to try to turn everything off before turning everything on.
Also, I improved the logic for detecting when the runs selection is back in the default state, so that we can avoid generating long URI strings wherever possible.
Change: 152188948

* Autogenerated Change: Change TensorBoard TAG to 52
Change: 152189000

* Remove warning that only happening with config cuda.
Change: 152189205

* Make resource variable shared name consistent with non-resource variables.

Remove colocation constraint from resource variable cached value with the
variable itself.
Change: 152192203

* Add a way to specify the optimization order; refactor and add constant folding to meta optimizer.
Change: 152193646

* Backport fixes and improvements from external Keras.
Change: 152198296

* Merge changes from github.
Change: 152200430

* Go: Update generated wrapper functions for TensorFlow ops.
Change: 152200754

* Update ops-related pbtxt files.
Change: 152203174

* Make ImportGraphDef() work with functions.

In addition to modify graph_constructor.cc, this patch adds some other
functionality to enable importing fucntions:
* Ability to add FunctionDefLibraries to Graphs and
  FunctionLibraryDefinitions (in addition to existing functions)
* FunctionDefsEqual() utility function
Change: 152205258

* Expand contrib test to more than just test targets.
Change: 152206822

* Preserve graph version during optimization
Change: 152213262

* Exclude enter and exit nodes from shape refiner's constant folding.
Change: 152213637

* Allow reshape_mover and algebraic_simplifier to make multiple mutations, by avoiding the short-circuit
std::any_of.
Change: 152232810

* Fix dynamic_rnn transpose bug (can input/output non-3d tensors).

Also a few cleanups to RNN code.
Change: 152267628

* Fix flaky tests
Change: 152272801

* Add an auto parallelization grappler optimization pass.
Change: 152276787

* Change json.decode.JSONDecodeError to ValueError.  JSONDecodeError seems to be
the exception used in the simplejson module, not the json module.
Change: 152278012

* Internal change.
Change: 152281471

* [XLA] Force buffer sharing of separate while instructions.
Change: 152288540

* replica_device_setter should work for resource variables
Change: 152289915

* Fix ./configure script
1. Add %workspace% in .bazelrc file when using import statement
2. Write action_env into bazelrc file for required environment variables for OpenCL support
Change: 152290700

* Pointing a number of Tensorboard graph visualization-related help links to the new locations for the correspondent API documentation.
Change: 152293459

* Restore most of pull request #8606

Pull request #8606 added str(Label(...)) for most dependencies in
tensorflow.bzl, allowing most functions to be used from repositories which
include TensorFlow as a submodule.  Unfortunately, it broke when pulled into
Google and was removed in cl/152200430.  This CL restores the change, except
for two Android-only functions; these were the only problematic bits.
Change: 152297413

* Removed dead code in Estimator.
Change: 152297597

* Assert rank is at least equal to new_rank for `_sparse_inner_flatten`.
Change: 152303319

* Extend quantization ranges to include 0.0f.
Change: 152304380

* Remove Keras config file saving.
Change: 152306552

* API backwards compatibility tests.
Change: 152310869

* [TF:XLA] Add a test for an R3 -> R4 broadcast.
Change: 152313967

* Fix the problem that no enough placeholders for persistent tensor
batch delete

The deleter_key is always a device_name, hence there is only one
of it. Hence, we cannot delete >1 handles at one time.

In the fix, it creates delete placeholder on demand, the max
number of placeholders is _DEAD_HANDLES_THRESHOLD.
Change: 152322770

* [XLA] Add several reduction tests.
Change: 152323510

* Added the memory optimizer to the meta optimizer.
Change: 152323689

* Started a set of utilities to categorize op types
Change: 152329057

* Add AudioSpectrogram op to TensorFlow for audio feature generation
Change: 152332221

* Update ops-related pbtxt files.
Change: 152332812

* Automated rollback of change 152332221
Change: 152333917

* Call Py_CLEAR on dead fields during TF_RESOURCE-to-ndarray conversion
Change: 152338333

* [TF contrib seq2seq] Initial, incomplete implementation of beam search decoder.

**DOES NOT WORK, pushed for collaboration only**
Change: 152343927

* [XLA] Change HloPassPipeline to disallow Add* calls after Run.
Change: 152345578

* Automated rollback of change 152332812
Change: 152349057

* Remove all 64/32 bit compiler warnings from core/ops.
Change: 152353506

* libtensorflow.so: Don't export private symbols.

With this change, libtensorflow.so will only export
functions defined in c_api.h. This also results in
a decreased binary size of libtensorflow.so.

On Linux the decrease was from roughly 150MB to 67MB.
On OS X it was from roughly 101MB to 82MB.

Also fixes #8923
Change: 152366053

* Add Elu ops in XLA.
Change: 152383201

* Fixed test. ('broadcast_dims' has size 1)
Change: 152383633

* Add more detailed error message for rank assertion in _sparse_inner_flatten.
Change: 152397909

* tensor_bundle: propagrates errors related to directory creation.
Change: 152401909

* matrix_adjoint added to contrib/linalg/linear_operator_util
Change: 152404828

* Add an is_active method to plugins

This method determines whether a plugin is active. A plugin may be inactive if say it lacks data. This new is_active method allows us to add a route to TensorBoard noting which plugins are active. The frontend could then avoid querying routes of inactive plugins.
Change: 152406232

* Replace a gather op for shapes by a stack op so dilated convolutions can be
placed on GPU even with strict placing (before the gather went to CPU).
Change: 152411159

* [TF:XLA] Implement BatchToSpace, BatchToSpaceND, SpaceToBatch, SpaceToBatchND.
Fix crashes in core implementations of the same operators for zero-sized blocks.
Change: 152416903

* Estimator saves relative paths in checkpoint.
Change: 152420211

* Fix layers_test exception regex matching.
Change: 152422855

* Unhide bijectors. Correct TransformedDistribution docstring.
Change: 152424418

* Choosing a saner default for min_eval_frequency in the constructor for Experiment for the GCS file system, because the default of 1 causes performance problems.
Change: 152439984

* Inherit use_resource from scope for partitioned variables.
Change: 152442103

* Support quantized reshape in hexagon runtime
Change: 152445539

* tfdbg CLI: add command list_source (ls) + UI fixes and improvements

The new list_source (shorthand: ls) command lists Python source files responsible for constructing the nodes and tensors encountered in the run() call.

It divides the source files into two categories and list them separately.
1) files that are not part of the TensorFlow Python library, and
2) files that are a part of it.

The list contains information about how many nodes, tensors and dumps of tensors the files is responsible for. The file paths contain clickable links to the existing print_source/ps command.

The list_source/ls command supports filtering by file-path and node-name regex patterns.

UI fixes:
* Fixed inconsistent black vs. transparent background color that made the layout look messy on some terminal types. Now using the transparent color for default font color consistently.
* In the print_source command output, add clickable links to expand source lines and graph elements.
Change: 152446002

* tfcompile: Be a little more verbose about missing required flags.

Fixes #9014
Change: 152446338

* Disable failing test cases in pooling_ops_test.
Change: 152447322

* Register more types for tf.image_crop_and_resize(). Resolves #9020.
Change: 152448160

* Automated rollback of change 152439984
Change: 152450929

* Add a route to TensorBoard for fetching plugin names

Specifically, we add a /data/plugins_listing route to the TensorBoard application. This route responds with an object mapping the name of each initialized plugin to whether it is active.

This route could help the frontend avoid issuing requests to inactive plugins.

Ordered the listing of routes within application.py so there is a little more organization.

Refactored the test for application to use a fake plugin.
Change: 152451390

* Added the ability to retrieve the amount of usable gpu memory
Change: 152453470

* Allow to set session ConfigProto in RunConfig and use it in Estimator.
Change: 152454548

* Colocate ResourceVariable reads with their handles.
Change: 152455939

* tfdbg: update doc for new command list_source/ls
Change: 152456128

* Make rnn directions slightly easier to follow.
Change: 152456296

* Internal change
Change: 152458104

* Adds batch renormalization.

NOTE: if you use renormalization, you might want to use faster moving average updates, i.e. lower `decay` values.
Change: 152458872

* When using ImportGraphDef with a passed in ShapeRefiner, use the
producer version of the GraphDef when importing; the ShapeRefiner
may be initialized with a different graph_def_version, so we need
to be able to override it.

The test failed without the change to graph_constructor and passes with it.
The test uses a legacy graph that is supported (reduction shape).
Change: 152459169

* Allow any iterable for `export_strategies` arg.
Change: 152461826

* Log steps/sec every 100 steps in MonitoredSession, as before.
Change: 152465320

* Fixes documentation to note that the in case of ties the identity of the return value of ArgMin and ArgMaxis not guaranteed .
Change: 152465346

* Automated rollback of change 152465346
Change: 152465844

* Fix shape inference fn on _ParallelConcatStart.
Change: 152466076

* Fix getting started guide

Explain numerical differences in loss
fix one example to print
Change: 152466119

* Remove superfluous mode argument.
Change: 152467334

* Add a tool that converts HLO computations to tensorflow GraphDef which can be visualized on Tensorboard.

This CL defines basic tensorflow::OpDef for each HLO instruction/node. More attributes (e.g. shapes, colors) will be added in the future.
Change: 152477918

* [TF:XLA] Increase shard count of //third_party/tensorflow/compiler/tests:spacetobatch_test to reduce flakiness when built under ASAN.
Change: 152496244

* Make projector plugin backend read assets saved via the PluginAssets API.

At the same time, keep backwards compatibility with the old way of looking up assets.
Change: 152504793

* Move MNIST pointers to mirror hosted by the CVDF on Google Cloud.
Fixes: #9031
Change: 152504901

* Merge changes from github.
Change: 152508170

* Update API after changing default step couter frequency before.
Change: 152517535

* Move a few random op helper functions to header files

1. shape_inference::RandomShape
2. OpKernel::MakeShape(Tensor, TensorShape*)
Change: 152522156

* addresses the divide by zero bug
Change: 152522488

* Clarify doc on tf.assign.
Change: 152523909

* Sparse adam for resource variables.
Change: 152525327

* Automated rollback of change 152310869
Change: 152528732

* Add an env_var tf_sync_on_finish_bool that block until device has finished all queued operations in a step if true.
Change: 152533676

* Add more node attributes for HloInstruction on Tensorboard e.g. shape and layout etc.
Change: 152534472

* Add tf.complex64 GPU support to tf.gather.

Also add ldg specializations for std::complex.
Change: 152537848

* Formatting changes
Change: 152544842

* Upgrade TensorBoard TypeScript to 2.2.1

See also: #8326
Change: 152545950

* TEST:  Getting reasonable test sizes on linalg library, removing need for
sharding.
Change: 152546409

* Disabling _testSourceUtilModuleReturnsTrue as its causing opensource issues.
Change: 152548721

* Fix race due to unsafe buffer forwarding in maxpooling second order gradients added in #6664.
Re-enable previously flaky tests.
Clean up a few minor things in maxpooling_op_gpu.cu.cc
Change: 152550050
This commit is contained in:
Rohan Jain 2017-04-07 18:04:26 -07:00 committed by GitHub
parent f8dce81aea
commit 52dcb2590b
176 changed files with 5882 additions and 882 deletions

6
configure vendored
View File

@ -56,7 +56,7 @@ rm -f .tf_configure.bazelrc
touch .tf_configure.bazelrc
touch .bazelrc
sed_hyphen_i "/tf_configure/d" .bazelrc
echo "import .tf_configure.bazelrc" >> .bazelrc
echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc
# Delete any leftover BUILD files from the Makefile build, which would interfere
# with Bazel parsing.
@ -284,6 +284,7 @@ export TF_NEED_CUDA
write_action_env_to_bazelrc "TF_NEED_CUDA" "$TF_NEED_CUDA"
export TF_NEED_OPENCL
write_action_env_to_bazelrc "TF_NEED_OPENCL" "$TF_NEED_OPENCL"
if [ "$TF_NEED_CUDA" == "1" ]; then
while [[ "$TF_CUDA_CLANG" == "" ]]; do
@ -547,6 +548,7 @@ while true; do
fi
if [ -e "$HOST_CXX_COMPILER" ]; then
export HOST_CXX_COMPILER
write_action_env_to_bazelrc "HOST_CXX_COMPILER" "$HOST_CXX_COMPILER"
break
fi
echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2
@ -570,6 +572,7 @@ while true; do
fi
if [ -e "$HOST_C_COMPILER" ]; then
export HOST_C_COMPILER
write_action_env_to_bazelrc "HOST_C_COMPILER" "$HOST_C_COMPILER"
break
fi
echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2
@ -600,6 +603,7 @@ while true; do
if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then
export COMPUTECPP_TOOLKIT_PATH
write_action_env_to_bazelrc "COMPUTECPP_TOOLKIT_PATH" "$COMPUTECPP_TOOLKIT_PATH"
break
fi
echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found"

View File

@ -351,9 +351,24 @@ filegroup(
# -------------------------------------------
cc_binary(
name = "libtensorflow.so",
linkopts = select({
"//tensorflow:darwin": [
"-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file
"//tensorflow/c:exported_symbols.lds",
],
"//tensorflow:windows": [],
"//conditions:default": [
"-z defs",
"-s",
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
"//tensorflow/c:version_script.lds",
],
}),
linkshared = 1,
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:exported_symbols.lds",
"//tensorflow/c:version_script.lds",
"//tensorflow/core:tensorflow",
],
)

View File

@ -45,6 +45,14 @@ tf_cuda_library(
}),
)
exports_files(
[
"version_script.lds",
"exported_symbols.lds",
],
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "tf_status_helper",
srcs = ["tf_status_helper.cc"],

View File

@ -0,0 +1 @@
_TF_*

View File

@ -0,0 +1,9 @@
VERS_1.0 {
# Export symbols in c_api.h.
global:
TF_*;
# Hide everything else.
local:
*;
};

View File

@ -52,7 +52,8 @@ const char kUsageHeader[] =
"header file that gives access to the functionality in the object file.\n"
"A typical invocation looks like this:\n"
"\n"
" $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n"
" $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt "
"--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
Status ReadProtoFile(const string& kind, const string& fname,
@ -73,6 +74,9 @@ void ParseTensorId(const string& name, TensorId* id) {
Status Main(const MainFlags& flags) {
// Process config.
Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
@ -85,6 +89,9 @@ Status Main(const MainFlags& flags) {
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
std::unique_ptr<Graph> graph;
@ -101,6 +108,9 @@ Status Main(const MainFlags& flags) {
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
StringPiece(obj.data(), obj.size())));
HeaderOpts header_opts;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
&header_opts.namespaces));
string header;
@ -131,12 +141,16 @@ int main(int argc, char** argv) {
QCHECK(parsed_flags_ok) << "\n" << usage;
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc == 1 && !flags.config.empty() &&
(flags.dump_fetch_nodes ||
(!flags.graph.empty() && !flags.entry_point.empty())))
<< "\n"
<< usage;
TF_QCHECK_OK(tensorflow::tfcompile::Main(flags));
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
"other than flags\n\n"
<< usage;
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
<< usage;
return 1;
} else {
TF_QCHECK_OK(status);
}
return 0;
}

View File

@ -305,6 +305,20 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "spacetobatch_op_test",
size = "medium",
srcs = ["spacetobatch_op_test.py"],
shard_count = 3,
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "ternary_ops_test",
size = "small",

View File

@ -107,6 +107,12 @@ class BinaryOpsTest(XLATestCase):
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([-75, -48, -21, 0], dtype=dtype))
self._testBinary(
gen_nn_ops._elu_grad,
np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype),
expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype))
self._testBinary(
gen_nn_ops._relu_grad,
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),

View File

@ -218,12 +218,11 @@ class OpTest : public ::testing::Test {
static constexpr int kDefaultMaxRank = 5;
static constexpr int64 kDefaultMaxDimensionSize = 20LL;
// Returns a random dimension size.
// Returns a random dimension size, in the range [min, max).
int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
// Returns a random shape. The tensor has rank in the range [min_rank,
// max_rank).
// Each dimension has size [0, kDefaultMaxDimensionSize].
// max_rank). Each dimension has size [min_size, max_size).
std::vector<int64> RandomDims(int min_rank = 0,
int max_rank = kDefaultMaxRank,
int64 min_size = 0,
@ -668,6 +667,9 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test";
return;
}
for (const Tensor& expected : expected_outputs) {
VLOG(1) << "Expected: " << expected.DebugString();
}
VLOG(1) << "Running test graph";
TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs));
@ -877,6 +879,79 @@ TEST_F(OpTest, BatchMatMul) {
});
}
TEST_F(OpTest, BatchToSpace) {
Repeatedly([this]() {
const int num_block_dims = 2;
std::vector<int64> block_dims =
RandomDims(num_block_dims, num_block_dims, 0, 5);
int64 block_size = RandomDim(0, 4);
std::vector<int64> input_dims(1 + num_block_dims + 1);
input_dims[0] = RandomDim();
for (int i = 0; i < num_block_dims; ++i) {
input_dims[0] *= block_size;
input_dims[1 + i] = block_dims[i];
}
input_dims[1 + num_block_dims] = RandomDim();
std::vector<int64> crop_vals;
std::uniform_int_distribution<int> distribution(0, 4);
for (int i = 0; i < num_block_dims; ++i) {
// Chooses crop values; does not always choose legal values.
crop_vals.push_back(distribution(generator()));
crop_vals.push_back(distribution(generator()));
}
Tensor crops;
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
TensorShape({num_block_dims, 2})));
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
.Input(RandomTensor(DT_FLOAT, input_dims))
.Input(crops)
.Attr("T", DT_FLOAT)
.Attr("block_size", block_size));
});
}
TEST_F(OpTest, BatchToSpaceND) {
Repeatedly([this]() {
std::vector<int64> block_dims = RandomDims(1, 3, 0, 5);
int num_block_dims = block_dims.size();
std::vector<int64> remaining_dims = RandomDims(0, 3);
std::vector<int64> block_multipliers =
RandomDims(block_dims.size(), block_dims.size(), 0, 4);
std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
input_dims[0] = RandomDim();
for (int i = 0; i < num_block_dims; ++i) {
input_dims[0] *= block_dims[i];
}
std::copy(block_multipliers.begin(), block_multipliers.end(),
input_dims.begin() + 1);
std::copy(remaining_dims.begin(), remaining_dims.end(),
input_dims.begin() + 1 + num_block_dims);
std::vector<int64> crop_vals;
std::uniform_int_distribution<int> distribution(0, 3);
for (int i = 0; i < num_block_dims; ++i) {
// Chooses crop values; does not always choose legal values.
crop_vals.push_back(distribution(generator()));
crop_vals.push_back(distribution(generator()));
}
Tensor crops;
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
TensorShape({num_block_dims, 2})));
ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("BatchToSpaceND")
.Input(RandomTensor(DT_FLOAT, input_dims))
.Input(test::AsTensor<int32>(
std::vector<int32>(block_dims.begin(), block_dims.end())))
.Input(crops)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, BiasAdd) {
Repeatedly([this]() {
auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank));
@ -1214,6 +1289,23 @@ TEST_F(OpTest, DynamicStitch) {
});
}
TEST_F(OpTest, Elu) {
Repeatedly([this]() {
ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Elu").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, EluGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad")
.Input(RandomTensor(DT_FLOAT, dims))
.Input(RandomTensor(DT_FLOAT, dims))
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Equal) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
@ -2019,6 +2111,87 @@ TEST_F(OpTest, SoftplusGrad) {
});
}
TEST_F(OpTest, SpaceToBatch) {
Repeatedly([this]() {
std::vector<int64> block_dims = RandomDims(4, 4, 0, 5);
const int num_block_dims = 2;
int64 block_size = RandomDim(0, 4);
std::vector<int64> input_dims(1 + num_block_dims + 1);
input_dims[0] = RandomDim();
for (int i = 0; i < num_block_dims; ++i) {
input_dims[1 + i] = block_dims[i] * block_size;
}
input_dims[1 + num_block_dims] = RandomDim();
std::vector<int64> padding_vals;
std::uniform_int_distribution<int> distribution(0, 7);
for (int i = 0; i < num_block_dims; ++i) {
int64 pad_before;
int64 pad_after;
do {
pad_before = distribution(generator());
pad_after = distribution(generator());
} while (pad_before + pad_after > input_dims[1 + i]);
input_dims[1 + i] -= pad_before + pad_after;
padding_vals.push_back(pad_before);
padding_vals.push_back(pad_after);
}
Tensor paddings;
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
TensorShape({num_block_dims, 2})));
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
.Input(RandomTensor(DT_FLOAT, input_dims))
.Input(paddings)
.Attr("T", DT_FLOAT)
.Attr("block_size", block_size));
});
}
TEST_F(OpTest, SpaceToBatchND) {
Repeatedly([this]() {
std::vector<int64> block_dims = RandomDims(1, 3, 0, 5);
int num_block_dims = block_dims.size();
std::vector<int64> remaining_dims = RandomDims(0, 3);
std::vector<int64> block_multipliers =
RandomDims(block_dims.size(), block_dims.size(), 0, 4);
std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
input_dims[0] = RandomDim();
for (int i = 0; i < num_block_dims; ++i) {
input_dims[1 + i] = block_dims[i] * block_multipliers[i];
}
std::copy(remaining_dims.begin(), remaining_dims.end(),
input_dims.begin() + 1 + num_block_dims);
std::vector<int64> padding_vals;
std::uniform_int_distribution<int> distribution(0, 7);
for (int i = 0; i < num_block_dims; ++i) {
int64 pad_before;
int64 pad_after;
do {
pad_before = distribution(generator());
pad_after = distribution(generator());
} while (pad_before + pad_after > input_dims[1 + i]);
input_dims[1 + i] -= pad_before + pad_after;
padding_vals.push_back(pad_before);
padding_vals.push_back(pad_after);
}
Tensor paddings;
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
TensorShape({num_block_dims, 2})));
ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("SpaceToBatchND")
.Input(RandomTensor(DT_FLOAT, input_dims))
.Input(test::AsTensor<int32>(
std::vector<int32>(block_dims.begin(), block_dims.end())))
.Input(paddings)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SparseMatMul) {
Repeatedly([this]() {
int64 x = RandomDim();

View File

@ -0,0 +1,266 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for SpaceToBatch and BatchToSpace ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import test
def space_to_batch_direct(input_array, block_shape, paddings):
"""Direct Python implementation of space-to-batch conversion.
This is used for tests only.
Args:
input_array: N-D array
block_shape: 1-D array of shape [num_block_dims].
paddings: 2-D array of shape [num_block_dims, 2].
Returns:
Converted tensor.
"""
input_array = np.array(input_array)
block_shape = np.array(block_shape)
num_block_dims = len(block_shape)
paddings = np.array(paddings).reshape((len(block_shape), 2))
padded = np.pad(input_array,
pad_width=([[0, 0]] + list(paddings) + [[0, 0]] *
(input_array.ndim - 1 - num_block_dims)),
mode="constant")
reshaped_padded_shape = [input_array.shape[0]]
output_shape = [input_array.shape[0] * np.prod(block_shape)]
for block_dim, block_shape_value in enumerate(block_shape):
reduced_size = padded.shape[block_dim + 1] // block_shape_value
reshaped_padded_shape.append(reduced_size)
output_shape.append(reduced_size)
reshaped_padded_shape.append(block_shape_value)
reshaped_padded_shape.extend(input_array.shape[num_block_dims + 1:])
output_shape.extend(input_array.shape[num_block_dims + 1:])
reshaped_padded = padded.reshape(reshaped_padded_shape)
permuted_reshaped_padded = np.transpose(reshaped_padded, (
list(np.arange(num_block_dims) * 2 + 2) + [0] +
list(np.arange(num_block_dims) * 2 + 1) + list(
np.arange(input_array.ndim - num_block_dims - 1) + 1 + num_block_dims
* 2)))
return permuted_reshaped_padded.reshape(output_shape)
class SpaceToBatchTest(XLATestCase):
"""Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
def _testPad(self, inputs, paddings, block_size, outputs):
with self.test_session() as sess, self.test_scope():
for dtype in self.float_types:
# outputs = space_to_batch(inputs)
placeholder = array_ops.placeholder(dtype)
x_tf = gen_array_ops._space_to_batch(
placeholder, paddings, block_size=block_size)
self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs)
# inputs = batch_to_space(outputs)
x_tf = gen_array_ops._batch_to_space(
placeholder, paddings, block_size=block_size)
self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs)
def _testOne(self, inputs, block_size, outputs):
paddings = np.zeros((2, 2), dtype=np.int32)
self._testPad(inputs, paddings, block_size, outputs)
# [1, 2, 2, 1] <-> [4, 1, 1, 1]
def testSmallInput2x2(self):
x_np = [[[[1], [2]], [[3], [4]]]]
block_size = 2
x_out = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
self._testOne(x_np, block_size, x_out)
# [1, 2, 2, 1] <-> [1, 3, 3, 1] (padding) <-> [9, 1, 1, 1]
def testSmallInput2x2Pad1x0(self):
x_np = [[[[1], [2]], [[3], [4]]]]
paddings = np.array([[1, 0], [1, 0]], dtype=np.int32)
block_size = 3
x_out = [[[[0]]], [[[0]]], [[[0]]], [[[0]]], [[[1]]], [[[2]]], [[[0]]],
[[[3]]], [[[4]]]]
self._testPad(x_np, paddings, block_size, x_out)
# Test with depth larger than 1.
# [1, 2, 2, 3] <-> [4, 1, 1, 3]
def testDepthInput2x2(self):
x_np = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]
block_size = 2
x_out = [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]]
self._testOne(x_np, block_size, x_out)
# Test for larger input dimensions.
# [1, 4, 4, 1] <-> [4, 2, 2, 1]
def testLargerInput2x2(self):
x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]],
[[9], [10], [11], [12]], [[13], [14], [15], [16]]]]
block_size = 2
x_out = [[[[1], [3]], [[9], [11]]], [[[2], [4]], [[10], [12]]],
[[[5], [7]], [[13], [15]]], [[[6], [8]], [[14], [16]]]]
self._testOne(x_np, block_size, x_out)
# Test with batch larger than 1.
# [2, 2, 4, 1] <-> [8, 1, 2, 1]
def testBatchInput2x2(self):
x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
[[[9], [10], [11], [12]], [[13], [14], [15], [16]]]]
block_size = 2
x_out = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]],
[[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]]
self._testOne(x_np, block_size, x_out)
# Tests for larger input spatial dimensions AND batch larger than 1, to ensure
# that elements are correctly laid out spatially and properly interleaved
# along the batch dimension.
# [2, 4, 4, 1] <-> [8, 2, 2, 1]
def testLargerInputBatch2x2(self):
x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]],
[[9], [10], [11], [12]], [[13], [14], [15], [16]]],
[[[17], [18], [19], [20]], [[21], [22], [23], [24]],
[[25], [26], [27], [28]], [[29], [30], [31], [32]]]]
x_out = [[[[1], [3]], [[9], [11]]], [[[17], [19]], [[25], [27]]],
[[[2], [4]], [[10], [12]]], [[[18], [20]], [[26], [28]]],
[[[5], [7]], [[13], [15]]], [[[21], [23]], [[29], [31]]],
[[[6], [8]], [[14], [16]]], [[[22], [24]], [[30], [32]]]]
block_size = 2
self._testOne(x_np, block_size, x_out)
class SpaceToBatchNDTest(XLATestCase):
"""Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops."""
def _testPad(self, inputs, block_shape, paddings, outputs):
block_shape = np.array(block_shape)
paddings = np.array(paddings).reshape((len(block_shape), 2))
with self.test_session() as sess, self.test_scope():
for dtype in self.float_types:
placeholder = array_ops.placeholder(dtype)
# outputs = space_to_batch(inputs)
x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs)
# inputs = batch_to_space(outputs)
placeholder = array_ops.placeholder(dtype)
x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings)
self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs)
def _testDirect(self, input_shape, block_shape, paddings):
inputs = np.arange(np.prod(input_shape), dtype=np.float32)
inputs = inputs.reshape(input_shape)
self._testPad(inputs, block_shape, paddings,
space_to_batch_direct(inputs, block_shape, paddings))
def testZeroBlockDimsZeroRemainingDims(self):
self._testPad(
inputs=[1, 2],
block_shape=[],
paddings=[],
outputs=[1, 2],)
def testZeroBlockDimsOneRemainingDim(self):
self._testPad(
inputs=[[1, 2], [3, 4]],
block_shape=[],
paddings=[],
outputs=[[1, 2], [3, 4]])
# Same thing, but with a no-op block dim.
self._testPad(
inputs=[[1, 2], [3, 4]],
block_shape=[1],
paddings=[[0, 0]],
outputs=[[1, 2], [3, 4]])
def testZeroBlockDimsTwoRemainingDims(self):
self._testPad(
inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
block_shape=[],
paddings=[],
outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
# Same thing, but with a no-op block dim.
self._testPad(
inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
block_shape=[1],
paddings=[[0, 0]],
outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
# Same thing, but with two no-op block dims.
self._testPad(
inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
block_shape=[1, 1],
paddings=[[0, 0], [0, 0]],
outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
def testOneBlockDimZeroRemainingDims(self):
self._testPad(
inputs=[[1, 2, 3], [4, 5, 6]],
block_shape=[2],
paddings=[1, 0],
outputs=[[0, 2], [0, 5], [1, 3], [4, 6]])
def testOneBlockDimOneRemainingDim(self):
self._testPad(
inputs=[[[1, 11], [2, 21], [3, 31]], [[4, 41], [5, 51], [6, 61]]],
block_shape=[2],
paddings=[1, 0],
outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
[[4, 41], [6, 61]]])
def testDirect(self):
# Test with zero-size remaining dimension.
self._testDirect(
input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])
# Test with zero-size blocked dimension.
self._testDirect(
input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])
# Test with padding up from zero size.
self._testDirect(
input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]])
self._testDirect(
input_shape=[3, 3, 4, 5, 2],
block_shape=[3, 4, 2],
paddings=[[1, 2], [0, 0], [3, 0]])
self._testDirect(
input_shape=[3, 3, 4, 5, 2],
block_shape=[3, 4, 2, 2],
paddings=[[1, 2], [0, 0], [3, 0], [0, 0]])
self._testDirect(
input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
block_shape=[1, 1, 3, 4, 2, 2],
paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]])
self._testDirect(
input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
block_shape=[1, 1, 3, 4, 2, 2, 1],
paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0], [0, 0]])
if __name__ == "__main__":
test.main()

View File

@ -209,6 +209,11 @@ class UnaryOpsTest(XLATestCase):
[-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.elu,
np.array([[-1, 0, 1]], dtype=dtype),
expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.relu,
np.array([[-1, 1]], dtype=dtype),

View File

@ -35,6 +35,9 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Any", "reduction_indices"},
{"ArgMax", "dimension"},
{"AvgPoolGrad", "orig_input_shape"},
{"BatchToSpace", "crops"},
{"BatchToSpaceND", "block_shape"},
{"BatchToSpaceND", "crops"},
{"BroadcastGradientArgs", "s0"},
{"BroadcastGradientArgs", "s1"},
{"Concat", "concat_dim"},
@ -69,6 +72,9 @@ Status BackwardsConstAnalysis(const Graph& g,
{"ReverseV2", "axis"},
{"Slice", "begin"},
{"Slice", "size"},
{"SpaceToBatch", "paddings"},
{"SpaceToBatchND", "block_shape"},
{"SpaceToBatchND", "paddings"},
{"Split", "split_dim"},
{"SplitV", "split_dim"},
{"SplitV", "size_splits"},

View File

@ -15,6 +15,7 @@ tf_kernel_library(
srcs = [
"aggregate_ops.cc",
"batch_matmul_op.cc",
"batchtospace_op.cc",
"bcast_ops.cc",
"bias_ops.cc",
"binary_ops.cc",
@ -26,6 +27,7 @@ tf_kernel_library(
"depthwise_conv_ops.cc",
"diag_op.cc",
"dynamic_stitch_op.cc",
"elu_op.cc",
"fill_op.cc",
"function_ops.cc",
"identity_op.cc",
@ -49,6 +51,7 @@ tf_kernel_library(
"shape_op.cc",
"slice_op.cc",
"softmax_op.cc",
"spacetobatch_op.cc",
"split_op.cc",
"strided_slice_op.cc",
"tile_ops.cc",

View File

@ -0,0 +1,186 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
namespace tensorflow {
namespace {
void BatchToSpace(XlaOpKernelContext* ctx,
const xla::ComputationDataHandle& input, DataType input_dtype,
const TensorShape& input_tensor_shape,
gtl::ArraySlice<int64> block_shape,
const xla::Literal& crops) {
const int input_rank = input_tensor_shape.dims();
const gtl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
const int block_rank = block_shape.size();
OP_REQUIRES(
ctx, input_rank >= 1 + block_rank,
errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
" instead of ", input_rank));
gtl::ArraySlice<int64> remainder_shape(input_shape);
remainder_shape.remove_prefix(1 + block_rank);
OP_REQUIRES(
ctx,
xla::ShapeUtil::Rank(crops.shape()) == 2 &&
block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
errors::InvalidArgument("crops should have shape [", block_rank,
", 2] instead of ",
xla::ShapeUtil::HumanString(crops.shape())));
xla::ComputationBuilder* b = ctx->builder();
const int64 batch_size = input_shape[0];
// Compute the product of the block_shape values.
int64 block_num_elems = 1;
for (int i = 0; i < block_rank; ++i) {
block_num_elems *= block_shape[i];
}
OP_REQUIRES(ctx, block_num_elems > 0,
errors::InvalidArgument(
"The product of the block dimensions must be positive"));
// 1. Reshape `input` to `reshaped` of shape:
// [block_shape[0], ..., block_shape[M-1],
// batch / prod(block_shape),
// input_shape[1], ..., input_shape[N-1]]
OP_REQUIRES(
ctx, batch_size % block_num_elems == 0,
errors::InvalidArgument("Input batch dimension (", batch_size,
") is not divisible by product of block sizes (",
block_num_elems, ")"));
std::vector<int64> reshaped_shape(input_rank + block_rank);
std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin());
reshaped_shape[block_rank] = batch_size / block_num_elems;
std::copy(input_shape.begin() + 1, input_shape.end(),
reshaped_shape.begin() + block_rank + 1);
xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce `permuted` of shape
// [batch / prod(block_shape),
//
// input_shape[1], block_shape[0],
// ...,
// input_shape[M], block_shape[M-1],
//
// input_shape[M+1], ..., input_shape[N-1]]
std::vector<int64> permutation(reshaped_shape.size());
permutation[0] = block_rank;
for (int i = 0; i < block_rank; ++i) {
permutation[1 + 2 * i] = block_rank + 1 + i;
permutation[1 + 2 * i + 1] = i;
}
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation);
// 3. Reshape `permuted` to produce `reshaped_permuted` of shape
// [batch / prod(block_shape),
//
// input_shape[1] * block_shape[0],
// ...,
// input_shape[M] * block_shape[M-1],
//
// input_shape[M+1],
// ...,
// input_shape[N-1]]
std::vector<int64> reshaped_permuted_shape(input_rank);
reshaped_permuted_shape[0] = batch_size / block_num_elems;
for (int i = 0; i < block_rank; ++i) {
reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i];
}
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_permuted_shape.begin() + 1 + block_rank);
xla::ComputationDataHandle reshaped_permuted =
b->Reshape(permuted, reshaped_permuted_shape);
// 4. Crop the start and end of dimensions `[1, ..., M]` of
// `reshaped_permuted` according to `crops` to produce the output of shape:
// [batch / prod(block_shape),
//
// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
// ...,
// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
//
// input_shape[M+1], ..., input_shape[N-1]]
std::vector<int64> start_indices(input_rank, 0);
std::vector<int64> end_indices = reshaped_permuted_shape;
for (int i = 0; i < block_rank; ++i) {
int64 crop_start = xla::LiteralUtil::Get<int64>(crops, {i, 0});
int64 crop_end = xla::LiteralUtil::Get<int64>(crops, {i, 1});
OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
errors::InvalidArgument("Crops must be non-negative"));
start_indices[1 + i] = crop_start;
end_indices[1 + i] -= crop_end;
OP_REQUIRES(
ctx, start_indices[1 + i] <= end_indices[1 + i],
errors::InvalidArgument(
"Cropped size must be non-negative: start: ", crop_start,
" end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
}
xla::ComputationDataHandle output =
b->Slice(reshaped_permuted, start_indices, end_indices);
ctx->SetOutput(0, output);
}
class BatchToSpaceNDOp : public XlaOpKernel {
public:
explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
std::vector<int64> block_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
xla::Literal crops;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops));
BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
block_shape, crops);
}
};
REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp);
class BatchToSpaceOp : public XlaOpKernel {
public:
explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
OP_REQUIRES(
ctx, block_size_ > 1,
errors::InvalidArgument("Block size should be > 1: ", block_size_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::Literal crops;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops));
BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
{block_size_, block_size_}, crops);
}
private:
int block_size_;
};
REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Native XLA implementations of XLA Elu Ops
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
namespace tensorflow {
namespace {
class EluOp : public XlaOpKernel {
public:
explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto pred = b->Gt(ctx->Input(0), zero);
const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1));
}
};
class EluGradOp : public XlaOpKernel {
public:
explicit EluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return lhs * (1 + rhs).
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto grad = ctx->Input(0);
const auto activation = ctx->Input(1);
const auto exp_grad = b->Mul(grad, b->Add(activation, one));
const auto pred = b->Gt(activation, zero);
ctx->SetOutput(0, b->Select(pred, grad, exp_grad));
}
};
REGISTER_XLA_OP(Name("Elu"), EluOp);
REGISTER_XLA_OP(Name("EluGrad"), EluGradOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,190 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
namespace tensorflow {
namespace {
void SpaceToBatch(XlaOpKernelContext* ctx,
const xla::ComputationDataHandle& input, DataType input_dtype,
const TensorShape& input_tensor_shape,
gtl::ArraySlice<int64> block_shape,
const xla::Literal& paddings) {
const int input_rank = input_tensor_shape.dims();
const gtl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
const int block_rank = block_shape.size();
OP_REQUIRES(
ctx, input_rank >= 1 + block_rank,
errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
" instead of ", input_rank));
gtl::ArraySlice<int64> remainder_shape(input_shape);
remainder_shape.remove_prefix(1 + block_rank);
OP_REQUIRES(
ctx,
xla::ShapeUtil::Rank(paddings.shape()) == 2 &&
block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
errors::InvalidArgument("paddings should have shape [", block_rank,
", 2] instead of ",
xla::ShapeUtil::HumanString(paddings.shape())));
xla::ComputationBuilder* b = ctx->builder();
// 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
// input according to `paddings` to produce `padded` of shape `padded_shape`.
xla::PaddingConfig padding_config;
std::vector<int64> padded_shape(input_shape.begin(), input_shape.end());
int64 block_num_elems = 1LL;
padding_config.add_dimensions(); // Don't pad the batch dimension.
for (int i = 0; i < block_rank; ++i) {
auto* dim = padding_config.add_dimensions();
int64 pad_start = xla::LiteralUtil::Get<int64>(paddings, {i, 0});
int64 pad_end = xla::LiteralUtil::Get<int64>(paddings, {i, 1});
OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
errors::InvalidArgument("Paddings must be non-negative"));
dim->set_edge_padding_low(pad_start);
dim->set_edge_padding_high(pad_end);
padded_shape[1 + i] += pad_start + pad_end;
block_num_elems *= block_shape[i];
}
// Don't pad the remainder dimensions.
for (int i = 0; i < remainder_shape.size(); ++i) {
padding_config.add_dimensions();
}
OP_REQUIRES(ctx, block_num_elems > 0,
errors::InvalidArgument(
"The product of the block dimensions must be positive"));
xla::ComputationDataHandle padded =
b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
// 2. Reshape `padded` to `reshaped_padded` of shape:
//
// [batch] +
// [padded_shape[1] / block_shape[0],
// block_shape[0],
// ...,
// padded_shape[M] / block_shape[M-1],
// block_shape[M-1]] +
// remaining_shape
const int64 batch_size = input_shape[0];
std::vector<int64> reshaped_padded_shape(input_rank + block_rank);
reshaped_padded_shape[0] = batch_size;
for (int i = 0; i < block_rank; ++i) {
OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
errors::InvalidArgument("padded_shape[", 1 + i,
"]=", padded_shape[1 + i],
" is not divisible by block_shape[", i,
"]=", block_shape[i]));
reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
}
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_padded_shape.begin() + 1 + 2 * block_rank);
xla::ComputationDataHandle reshaped_padded =
b->Reshape(padded, reshaped_padded_shape);
// 3. Permute dimensions of `reshaped_padded` to produce
// `permuted_reshaped_padded` of shape:
//
// block_shape +
// [batch] +
// [padded_shape[1] / block_shape[0],
// ...,
// padded_shape[M] / block_shape[M-1]] +
// remaining_shape
std::vector<int64> permutation(reshaped_padded_shape.size());
for (int i = 0; i < block_rank; ++i) {
permutation[i] = 1 + 2 * i + 1;
permutation[block_rank + 1 + i] = 1 + 2 * i;
}
permutation[block_rank] = 0;
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
xla::ComputationDataHandle permuted_reshaped_padded =
b->Transpose(reshaped_padded, permutation);
// 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
//
// [batch * prod(block_shape)] +
// [padded_shape[1] / block_shape[0],
// ...,
// padded_shape[M] / block_shape[M-1]] +
// remaining_shape
// Determine the length of the prefix of block dims that can be combined
// into the batch dimension due to having no padding and block_shape=1.
std::vector<int64> output_shape(input_rank);
output_shape[0] = batch_size * block_num_elems;
for (int i = 0; i < block_rank; ++i) {
output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
}
std::copy(remainder_shape.begin(), remainder_shape.end(),
output_shape.begin() + 1 + block_rank);
xla::ComputationDataHandle output =
b->Reshape(permuted_reshaped_padded, output_shape);
ctx->SetOutput(0, output);
}
class SpaceToBatchNDOp : public XlaOpKernel {
public:
explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
std::vector<int64> block_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
xla::Literal paddings;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));
SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
block_shape, paddings);
}
};
REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp);
class SpaceToBatchOp : public XlaOpKernel {
public:
explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
OP_REQUIRES(
ctx, block_size_ > 1,
errors::InvalidArgument("Block size should be > 1: ", block_size_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::Literal paddings;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));
SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
{block_size_, block_size_}, paddings);
}
private:
int block_size_;
};
REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp);
} // namespace
} // namespace tensorflow

View File

@ -186,6 +186,31 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
return LiteralToInt64Vector(literal, out);
}
Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
xla::Literal* out) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
switch (literal.shape().element_type()) {
case xla::S32:
out->Clear();
*out->mutable_shape() = literal.shape();
out->mutable_shape()->set_element_type(xla::S64);
for (int32 x : literal.s32s()) {
out->add_s64s(x);
}
return Status::OK();
case xla::S64:
out->Swap(&literal);
return Status::OK();
default:
return errors::InvalidArgument(
"Invalid argument to ConstantInputAsInt64Literal: ",
xla::ShapeUtil::HumanString(literal.shape()));
}
}
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {

View File

@ -110,6 +110,9 @@ class XlaOpKernelContext {
// Converts a constant 1D int32 or int64 tensor into a vector of int64s.
Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
// Converts a constant int32 or int64 Tensor into an xla int64 Literal.
Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
// Converts a constant 1D int32 or int64 tensor into a TensorShape.
Status ConstantInputAsShape(int index, TensorShape* shape);

View File

@ -594,8 +594,10 @@ cc_test(
deps = [
":buffer_assignment",
":computation_tracker",
":copy_insertion",
":cpu_plugin",
":hlo",
":hlo_ordering",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",

View File

@ -868,7 +868,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
computation->root_instruction()->dimensions();
EXPECT_EQ(1, broadcast_dims.size());
EXPECT_TRUE(broadcast_dims[0] == 1 || broadcast_dims[0] == 2 ||
broadcast_dims[3] == 3);
broadcast_dims[0] == 3);
}
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {

View File

@ -41,6 +41,8 @@ limitations under the License.
namespace xla {
using ::tensorflow::gtl::FlatMap;
using ::tensorflow::gtl::FlatSet;
using ::tensorflow::strings::Appendf;
using ::tensorflow::strings::HumanReadableNumBytes;
@ -394,8 +396,8 @@ Status GatherComputationsByAllocationType(
// Sets for quickly checking membership. Computations are returned in vectors
// for stable iteration.
tensorflow::gtl::FlatSet<HloComputation*> thread_local_set;
tensorflow::gtl::FlatSet<HloComputation*> global_set;
FlatSet<HloComputation*> thread_local_set;
FlatSet<HloComputation*> global_set;
while (!worklist.empty()) {
auto worklist_front = worklist.front();
@ -554,10 +556,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
Status BufferAssigner::AssignBuffersForComputation(
const HloComputation* computation, bool is_thread_local,
const tensorflow::gtl::FlatSet<const HloInstruction*>* hlos_to_allocate,
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
colocated_allocations,
const FlatSet<const HloInstruction*>* hlos_to_allocate,
const FlatSet<const LogicalBuffer*>& colocated_buffers,
const FlatSet<BufferAllocation::Index>& colocated_allocations,
BufferAssignment* assignment) {
// Buffers are sorted and assigned to BufferAllocations in decreasing order of
// size.
@ -578,7 +579,7 @@ Status BufferAssigner::AssignBuffersForComputation(
// Generate a post order sort of instructions for sorting of the
// LogicalBuffers.
tensorflow::gtl::FlatMap<const HloInstruction*, int> post_order_position;
FlatMap<const HloInstruction*, int> post_order_position;
int position = 0;
for (auto* instruction : computation->MakeInstructionPostOrder()) {
post_order_position.emplace(instruction, position);
@ -590,7 +591,7 @@ Status BufferAssigner::AssignBuffersForComputation(
const BufferLiveness& liveness = assignment->liveness();
const std::vector<const HloInstruction*>* sequential_order =
liveness.hlo_ordering().SequentialOrder(*computation);
tensorflow::gtl::FlatSet<const LogicalBuffer*> unassigned_temp_buffers;
FlatSet<const LogicalBuffer*> unassigned_temp_buffers;
// Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
// first for simplicity. This means any previously created BufferAllocation is
@ -791,7 +792,7 @@ Status BufferAssigner::AssignBuffersForComputation(
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
const std::vector<const HloInstruction*>& sequence,
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers_to_assign,
const FlatSet<const LogicalBuffer*>& buffers_to_assign,
const HloComputation& computation, BufferAssignment* assignment) {
// Run the sequence of instructions through the heap simulator. The heuristic
// that seems to give the best results is lazy-best-fit, with all runs of
@ -881,40 +882,137 @@ void BufferAssigner::AddSetToColocatedBufferSets(
}
}
// Conceptually the same as AddSetToColocatedBufferSets, but specific to the
// colocated buffers for while instructions. 'colocated_set' contains the
// buffers for a single while instruction that must be colocated. The idea here
// is to apply a memory-saving heuristic for separate while instructions whose
// buffers are disjoint in liveness, by using the colocation mechanism to force
// buffer sharing. This often reduces memory for multi-layer RNNs.
//
// TODO(b/32491382): We should be able to remove this heuristic after we
// implement module-level liveness analysis, which would let us directly detect
// buffer sharing opportunities between the while instruction buffer and the
// buffers from the predicate and body computation, as well as sharing across
// different while instructions.
void BufferAssigner::AddWhileSetToColocatedBufferSets(
const std::vector<const LogicalBuffer*>& colocated_set,
const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo,
const HloComputation& computation, const BufferLiveness& buffer_liveness,
std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
CHECK(!colocated_set.empty());
// Parallel while loops cannot safely share colocated buffer sets.
if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) {
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
return;
}
// Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets
// are added in postorder over computations and instructions.
const int64 init_buffer_size = buffer_size_(*while_init_buffer);
for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) {
const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i];
// Skip predecessor sets not associated with while loops.
if (std::all_of(predecessor_set.begin(), predecessor_set.end(),
[](const LogicalBuffer* buffer) {
return buffer->instruction()->opcode() !=
HloOpcode::kWhile;
})) {
continue;
}
// Skip predecessor sets already associated with 'while_hlo'.
if (std::any_of(predecessor_set.begin(), predecessor_set.end(),
[&while_hlo](const LogicalBuffer* buffer) {
return buffer->instruction() == while_hlo;
})) {
continue;
}
// Build vector of predecessor while result buffers.
std::vector<const LogicalBuffer*> predecessor_while_buffers;
for (const LogicalBuffer* buffer : predecessor_set) {
if (buffer->instruction()->opcode() == HloOpcode::kWhile &&
buffer_size_(*buffer) == init_buffer_size &&
buffer->instruction()->parent() == &computation) {
predecessor_while_buffers.push_back(buffer);
}
}
if (predecessor_while_buffers.empty()) {
continue;
}
// Skip predecessor set if the live range of any predecessor buffers
// overlaps with 'while_init_buffer'. Note that tuple element buffer
// forwarding can cause the same buffer to appear on both sides of the
// interference comparison below.
if (std::any_of(
predecessor_while_buffers.begin(), predecessor_while_buffers.end(),
[while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) {
return while_init_buffer->id() != buffer->id() &&
buffer_liveness.MayInterfere(*while_init_buffer, *buffer);
})) {
continue;
}
// All our checks have passed; merge 'predecessor_set' with 'colocated_set',
// and add the merged set to 'colocated_buffer_sets'. This forces the
// colocation of buffers across different while instructions.
FlatSet<const LogicalBuffer*> unique;
unique.insert(predecessor_set.begin(), predecessor_set.end());
unique.insert(colocated_set.begin(), colocated_set.end());
std::vector<const LogicalBuffer*> merged_set(unique.begin(), unique.end());
AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets);
return;
}
// Failed to merge into predecessor set; add 'colocated_set' as-is.
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
}
namespace {
// Checks that points-to set of 'instruction' is unambiguous and distinct
// (ensured by CopyInsertion), then adds the buffer from the points-to set at
// 'index' to 'colocated_set'.
void AddBufferToColocatedSet(const HloInstruction* instruction,
const ShapeIndex& index,
const TuplePointsToAnalysis& points_to_analysis,
std::vector<const LogicalBuffer*>* colocated_set) {
const LogicalBuffer* AddBufferToColocatedSet(
const HloInstruction* instruction, const ShapeIndex& index,
const TuplePointsToAnalysis& points_to_analysis,
std::vector<const LogicalBuffer*>* colocated_set) {
// CopyInsertion ensures root points-to set is unambiguous and distinct.
const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
CHECK(!points_to.IsAmbiguous());
CHECK(points_to.IsDistinct());
colocated_set->push_back(points_to.element(index)[0]);
return colocated_set->back();
}
} // namespace
// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
// in the same allocation (currently just supports kWhile and kCall).
void BufferAssigner::BuildColocatedBufferSets(
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
const HloModule* module, const BufferLiveness& buffer_liveness,
std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
for (auto& computation : module->computations()) {
for (auto& instruction : computation->instructions()) {
const TuplePointsToAnalysis& points_to_analysis =
buffer_liveness.points_to_analysis();
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
for (const HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
const HloOpcode opcode = instruction->opcode();
if (opcode == HloOpcode::kWhile) {
HloInstruction* while_hlo = instruction.get();
const HloInstruction* while_hlo = instruction;
TF_CHECK_OK(ShapeUtil::ForEachSubshape(
while_hlo->shape(),
[this, while_hlo, &points_to_analysis, colocated_buffer_sets](
const Shape& /*subshape*/, const ShapeIndex& index) {
[this, while_hlo, &points_to_analysis, &buffer_liveness,
computation, colocated_buffer_sets](const Shape& /*subshape*/,
const ShapeIndex& index) {
std::vector<const LogicalBuffer*> colocated_set;
// Add while.init.
AddBufferToColocatedSet(while_hlo->operand(0), index,
points_to_analysis, &colocated_set);
auto* init_buffer =
AddBufferToColocatedSet(while_hlo->operand(0), index,
points_to_analysis, &colocated_set);
// Add while.result.
AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
&colocated_set);
@ -930,12 +1028,15 @@ void BufferAssigner::BuildColocatedBufferSets(
AddBufferToColocatedSet(
while_hlo->while_body()->root_instruction(), index,
points_to_analysis, &colocated_set);
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
AddWhileSetToColocatedBufferSets(
colocated_set, init_buffer, while_hlo, *computation,
buffer_liveness, colocated_buffer_sets);
return Status::OK();
}));
} else if (opcode == HloOpcode::kCall) {
HloInstruction* call_hlo = instruction.get();
HloInstruction* root_hlo = call_hlo->to_apply()->root_instruction();
const HloInstruction* call_hlo = instruction;
const HloInstruction* root_hlo =
call_hlo->to_apply()->root_instruction();
TF_CHECK_OK(ShapeUtil::ForEachSubshape(
call_hlo->shape(),
[this, call_hlo, root_hlo, &points_to_analysis,
@ -961,8 +1062,8 @@ void BufferAssigner::BuildColocatedBufferSets(
void BufferAssigner::AssignColocatedBufferSets(
const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
BufferAssignment* assignment,
tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations) {
FlatSet<const LogicalBuffer*>* colocated_buffers,
FlatSet<BufferAllocation::Index>* colocated_allocations) {
for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
BufferAllocation* allocation = nullptr;
for (const LogicalBuffer* buffer : colocated_buffer_set) {
@ -1008,9 +1109,9 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to
// AssignBuffersForComputation for fast membership testing.
std::unique_ptr<tensorflow::gtl::FlatSet<const HloInstruction*>> hlo_set;
std::unique_ptr<FlatSet<const HloInstruction*>> hlo_set;
if (hlos_to_allocate != nullptr) {
hlo_set = MakeUnique<tensorflow::gtl::FlatSet<const HloInstruction*>>(
hlo_set = MakeUnique<FlatSet<const HloInstruction*>>(
hlos_to_allocate->begin(), hlos_to_allocate->end());
}
@ -1022,11 +1123,11 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// Once b/32491382 enables module-level liveness analysis, we may be able
// to assign colocated buffers (or at least reuse their allocation for
// buffers outside of the set) in AssignBuffersForComputation.
tensorflow::gtl::FlatSet<const LogicalBuffer*> colocated_buffers;
tensorflow::gtl::FlatSet<BufferAllocation::Index> colocated_allocations;
FlatSet<const LogicalBuffer*> colocated_buffers;
FlatSet<BufferAllocation::Index> colocated_allocations;
if (colocate_related_buffers_) {
std::vector<ColocatedBufferSet> colocated_buffer_sets;
BuildColocatedBufferSets(module, assignment->points_to_analysis(),
BuildColocatedBufferSets(module, assignment->liveness(),
&colocated_buffer_sets);
AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
&colocated_buffers, &colocated_allocations);

View File

@ -465,7 +465,7 @@ class BufferAssigner {
// ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
// which should be colocated in the same buffer allocation.
void BuildColocatedBufferSets(
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
const HloModule* module, const BufferLiveness& buffer_liveness,
std::vector<ColocatedBufferSet>* colocated_buffer_sets);
// For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
@ -482,6 +482,14 @@ class BufferAssigner {
const std::vector<const LogicalBuffer*>& colocated_set,
std::vector<ColocatedBufferSet>* colocated_buffer_sets);
// Conceptually the same as AddSetToColocatedBufferSets, but specific to the
// colocated buffers for while instructions.
void AddWhileSetToColocatedBufferSets(
const std::vector<const LogicalBuffer*>& colocated_set,
const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo,
const HloComputation& computation, const BufferLiveness& buffer_liveness,
std::vector<ColocatedBufferSet>* colocated_buffer_sets);
const HloModule* module_;
// Function which returns the buffer size for a given logical buffer (shape).

View File

@ -23,10 +23,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@ -1245,6 +1247,163 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
}
}
} // namespace
class WhileBufferAssignmentTest : public HloTestBase {
protected:
std::unique_ptr<HloComputation> BuildWhileConditionComputation(
const string& name) {
auto builder = HloComputation::Builder(name);
builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
auto ten = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
return builder.Build();
}
std::unique_ptr<HloComputation> BuildWhileBodyComputation(
const string& name) {
auto builder = HloComputation::Builder(name);
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
auto input = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
auto weights = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
auto output = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kMultiply, input, weights));
builder.AddInstruction(
HloInstruction::CreateTuple({input, weights, output}));
return builder.Build();
}
void RunCopyInsertion(HloModule* module) {
CopyInsertion copy_insertion;
EXPECT_IS_OK(copy_insertion.Run(module).status());
}
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
auto sequence =
CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
module, MakeUnique<SequentialHloOrdering>(module, sequence),
ByteSizeOf, alignment)
.ConsumeValueOrDie();
}
static int64 ByteSizeOf(const LogicalBuffer& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
}
Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
Shape loop_state_shape_ =
ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
};
TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
auto module = MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape_, "input0"));
auto weights0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto weights1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, data_shape_, "weights1"));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
auto body0 =
module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({input0, weights0, output0}));
auto while0 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
auto cond1 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
auto body1 =
module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
auto input1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({input1, weights1, output1}));
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
module->AddEntryComputation(builder.Build());
RunCopyInsertion(module.get());
auto assignment = RunBufferAssignment(module.get());
// While instruction 'while0' has no predecessor while instructions with
// which to share allocations.
// While instruction 'while1' can share allocations with the following
// buffers:
// *) while0[2], while1[0]
// *) while0[1], while1[1]
EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(),
assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
}
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
auto module = MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape_, "input0"));
auto weights0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
auto body0 =
module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({input0, weights0, output0}));
auto while0 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
auto cond1 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
auto body1 =
module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({input0, weights0, output1}));
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
module->AddEntryComputation(builder.Build());
RunCopyInsertion(module.get());
auto assignment = RunBufferAssignment(module.get());
EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(),
assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
}
} // namespace
} // namespace xla

View File

@ -165,4 +165,17 @@ bool HloOpcodeIsComparison(HloOpcode opcode) {
}
}
bool HloOpcodeIsVariadic(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kCall:
case HloOpcode::kConcatenate:
case HloOpcode::kFusion:
case HloOpcode::kMap:
case HloOpcode::kTuple:
return true;
default:
return false;
}
}
} // namespace xla

View File

@ -104,6 +104,9 @@ inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
// Returns true iff the given opcode is a comparison operation.
bool HloOpcodeIsComparison(HloOpcode opcode);
// Returns true iff the given opcode has variadic operands.
bool HloOpcodeIsVariadic(HloOpcode opcode);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_

View File

@ -40,6 +40,8 @@ void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module,
} // namespace
StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
run_called_ = true;
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
std::vector<string> tmp =

View File

@ -47,6 +47,7 @@ class HloPassPipeline : public HloPassInterface {
// Returns a reference to the added pass.
template <typename T, typename... Args>
T& AddPass(Args&&... args) {
CHECK(!run_called_) << "AddPass cannot be called after Run";
auto pass = new T(std::forward<Args>(args)...);
passes_.push_back(std::unique_ptr<T>(pass));
return *pass;
@ -57,6 +58,7 @@ class HloPassPipeline : public HloPassInterface {
// (it is required to always return "false" from its Run() method).
template <typename T, typename... Args>
T& AddInvariantChecker(Args&&... args) {
CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run";
auto pass = new T(std::forward<Args>(args)...);
invariant_checkers_.push_back(std::unique_ptr<T>(pass));
return *pass;
@ -70,6 +72,7 @@ class HloPassPipeline : public HloPassInterface {
Compiler::HloDumper dumper_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
bool run_called_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
};

View File

@ -101,12 +101,12 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
} // namespace
// User and operand can share buffers iff both instructions emit the same shape
// and layout, and 'user' meets one of the following two qualifications:
// *) Is element-wise.
// and layout, and 'user' meets one of the following qualifications:
// *) Is element-wise. Or...
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
// at operand 0.
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
// at operand 0. Or...
// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
@ -144,7 +144,8 @@ bool CanShareOperandBufferWithUser(
break;
}
return false;
} else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
} else if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.
std::vector<int64> operand_indices = user->OperandIndices(operand);

View File

@ -185,5 +185,73 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
*points_to_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
auto builder = HloComputation::Builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape update_shape = ShapeUtil::MakeShape(F32, {4});
Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
auto update = builder.AddInstruction(
HloInstruction::CreateParameter(1, update_shape, "update"));
auto starts = builder.AddInstruction(
HloInstruction::CreateParameter(2, starts_shape, "starts"));
auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, data, update, starts));
BuildModuleAndRunAnalysis(builder.Build());
// The DynamicUpdateSlice instruction can share with the data operand, but not
// with update or starts.
EXPECT_TRUE(
CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_));
EXPECT_FALSE(
CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_));
EXPECT_FALSE(
CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
auto make_cond = [this, &data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
return builder.Build();
};
auto make_body = [this, &data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Body");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data));
return builder.Build();
};
module_ = MakeUnique<HloModule>(TestName());
HloComputation* cond_computation =
module_->AddEmbeddedComputation(make_cond());
HloComputation* body_computation =
module_->AddEmbeddedComputation(make_body());
auto builder = HloComputation::Builder(TestName());
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
auto whil = builder.AddInstruction(HloInstruction::CreateWhile(
data_shape, cond_computation, body_computation, data));
computation_ = module_->AddEntryComputation(builder.Build());
RunAnalysis();
// The While instruction can share with the data operand.
EXPECT_TRUE(
CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_));
}
} // namespace
} // namespace xla

View File

@ -265,6 +265,37 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
*LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
auto builder = HloComputation::Builder(TestName());
Array3D<float> input_vals(2, 3, 4);
input_vals.FillRandom(1.0);
Array4D<float> expected(2, 3, 4, 5);
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 4; ++k) {
for (int m = 0; m < 5; ++m) {
expected(i, j, k, m) = input_vals(i, j, k);
}
}
}
}
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR3FromArray3D<float>(input_vals)));
// Broadcast vector in dimensions 2 and 3.
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
// Create HLO module, compile, and execute.
auto hlo_module = MakeUnique<HloModule>(TestName());
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear(
*LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
}
} // namespace
} // namespace xla

View File

@ -211,9 +211,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); }
XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); }
XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); }
XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); }
XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); }
XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); }
XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); }
XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); }
XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); }
@ -221,6 +221,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); }
XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
RunR1ToR0Test(16 * 1024 + 1);
}
XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) { RunR1ToR0Test(64 * 1024); }
XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) { RunR1ToR0Test(1024 * 1024); }
XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) { RunR1ToR0Test(4096 * 4096); }
XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); }

View File

@ -176,6 +176,52 @@ cc_binary(
],
)
cc_library(
name = "hlo_tfgraph_builder",
srcs = ["hlo_tfgraph_builder.cc"],
hdrs = ["hlo_tfgraph_builder.h"],
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "hlo_tfgraph_builder_test",
srcs = ["hlo_tfgraph_builder_test.cc"],
deps = [
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test_main",
],
)
cc_binary(
name = "dumped_computation_to_tf_graphdef",
srcs = ["dumped_computation_to_tf_graphdef.cc"],
deps = [
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -0,0 +1,139 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Usage: dumped_computation_to_tf_graph \
// --output_dir=/tmp/graphs/ some_binary_snapshot_proto*
//
// Dumps a tensorflow GraphDef in text format for a snapshot computation. The
// dumped graph is an HLO computation with HLO instructions as nodes and can be
// visualized on Tensorboard. Upload the dumped files on Tensorboard.
//
// some_binary_snapshot_proto is obtained by serializing the SessionModule from
// ServiceInterface::SnapshotComputation to disk.
#include <stdio.h>
#include <memory>
#include <string>
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::Env;
using tensorflow::io::JoinPath;
using tensorflow::strings::StrAppend;
namespace xla {
namespace tools {
namespace {
// Dumps all computations in the module to the given directory.
void DumpTfGraph(const HloModule& module, const string& directory_path) {
Env* env = Env::Default();
TF_CHECK_OK(env->RecursivelyCreateDir(directory_path));
string fname = module.name();
std::replace(fname.begin(), fname.end(), '/', '_');
// Since the file name will be used as the top-level scope name, clean it up
// to make it a valid scope name.
CleanNodeName(&fname);
StrAppend(&fname, ".pbtxt");
string path = JoinPath(directory_path, fname);
HloTfGraphBuilder builder;
TF_CHECK_OK(builder.AddComputation(*module.entry_computation()));
std::cout << "Dumping " << module.name() << " to " << path << std::endl;
TF_CHECK_OK(WriteTextProto(env, path, builder.GetGraphDef()));
}
} // namespace
void RealMain(tensorflow::gtl::ArraySlice<char*> args,
const string& output_dir) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
// To avoid adding a new flag, use local service and lower the computations
// locally.
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
// Build HloModule for each Computation and dump to file.
for (char* arg : args) {
SessionModule session_module;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
&session_module));
auto computation_status = client->LoadSnapshot(session_module);
if (!computation_status.ok()) {
fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
computation_status.status().ToString().c_str());
continue;
}
Computation computation = computation_status.ConsumeValueOrDie();
StatusOr<UserComputation*> user_computation_status =
local_service->computation_tracker().Resolve(computation.handle());
if (!user_computation_status.ok()) {
fprintf(stderr,
"failed to resolve computation to UserComputation %s: %s\n", arg,
user_computation_status.status().ToString().c_str());
continue;
}
auto* user_computation = user_computation_status.ValueOrDie();
StatusOr<std::unique_ptr<HloModule>> module_status =
local_service->computation_tracker().BuildHloModule(
user_computation->GetVersionedHandle());
if (!module_status.ok()) {
fprintf(stderr, "failed to build HloModule %s: %s\n", arg,
module_status.status().ToString().c_str());
continue;
}
DumpTfGraph(*module_status.ValueOrDie(), output_dir);
}
}
} // namespace tools
} // namespace xla
int main(int argc, char** argv) {
string output_dir = "";
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("output_dir", &output_dir,
"Directory to write GraphDef data to."),
};
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_ok || output_dir.empty()) {
LOG(QFATAL) << usage;
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
tensorflow::gtl::ArraySlice<char*> args(argv, argc);
args.pop_front(); // Pop off the binary name, argv[0]
xla::tools::RealMain(args, output_dir);
return 0;
}

View File

@ -0,0 +1,204 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
LIcensed under the Apache License, Version 2.0 (the "License");
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
using ::tensorflow::GraphDef;
using ::tensorflow::NodeDef;
using ::tensorflow::TensorShapeProto;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
using ::tensorflow::str_util::Join;
namespace xla {
namespace tools {
namespace {
string GetOpDefName(const HloInstruction* instruction) {
string name = StrCat("hlo-", HloOpcodeString(instruction->opcode()));
tensorflow::str_util::TitlecaseString(&name, "-");
name.erase(std::remove(name.begin(), name.end(), '-'), name.end());
if (instruction->opcode() == HloOpcode::kFusion) {
string fusion_name = ToString(instruction->fusion_kind());
StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1));
}
return name;
}
TensorShapeProto GetTensorShape(const HloInstruction* instruction) {
TensorShapeProto tensor_shape;
const Shape& shape = instruction->shape();
for (auto dim : shape.dimensions()) {
tensor_shape.add_dim()->set_size(dim);
}
return tensor_shape;
}
} // namespace
void CleanNodeName(string* name) {
name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
const string chars_to_replace = "<>[]";
auto pred = [&](char c) {
return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) !=
chars_to_replace.end();
};
std::replace_if(name->begin(), name->end(), pred, '_');
}
Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
LOG(INFO) << "Adding computation " << computation.name();
for (auto embedded : computation.MakeEmbeddedComputationsList()) {
LOG(INFO) << "Adding embedded computation " << embedded->name();
for (auto& instruction : embedded->instructions()) {
TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
}
}
for (auto& instruction : computation.instructions()) {
TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
}
return Status::OK();
}
const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; }
const string& HloTfGraphBuilder::GetNodeNameForInstruction(
const HloInstruction* instruction) {
if (ContainsKey(instruction_to_node_name_, instruction)) {
return instruction_to_node_name_[instruction];
}
// If an instruction is fused, put it in the subgraph of the fusion;
// otherwise, put it in the computation subgraph.
string node_name =
instruction->IsFused()
? GetNodeNameForInstruction(instruction->fusion_instruction())
: instruction->parent()->name();
string instruction_name = instruction->name();
if (instruction->opcode() == HloOpcode::kParameter) {
StrAppend(&instruction_name, ".", instruction->parameter_number());
}
StrAppend(&node_name, "/", instruction_name);
CleanNodeName(&node_name);
auto ret =
instruction_to_node_name_.insert(std::make_pair(instruction, node_name));
CHECK(ret.second);
return ret.first->second;
}
void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
NodeDef* node_def) const {
auto& attrs = *node_def->mutable_attr();
// Set the number of arguments for instructions that have variadic operands.
if (HloOpcodeIsVariadic(instruction->opcode())) {
tensorflow::AttrValue attr_value;
attr_value.set_i(instruction->operands().size());
attrs["arg_num"] = attr_value;
}
// Set the node type.
attrs["type"].set_s(
xla::PrimitiveType_Name(instruction->shape().element_type()));
// Set the shape of the output tensor. "_output_shapes" is a special attribute
// name used by Tensorboard for shapes of output tensors.
tensorflow::AttrValue shapes;
*shapes.mutable_list()->add_shape() = GetTensorShape(instruction);
attrs["_output_shapes"] = shapes;
// Set the layout.
if (LayoutUtil::HasLayout(instruction->shape())) {
string layout_string;
if (ShapeUtil::IsTuple(instruction->shape())) {
// For tuples, emit the full shape because the layout of a tuple is not
// represented in a single Layout field.
layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
} else {
layout_string = StrCat(
"{", Join(instruction->shape().layout().minor_to_major(), ","), "}");
}
attrs["layout"].set_s(layout_string);
}
// Set op-specific attributes.
switch (instruction->opcode()) {
case HloOpcode::kConcatenate:
case HloOpcode::kBroadcast:
case HloOpcode::kReduce:
case HloOpcode::kReverse:
case HloOpcode::kTranspose:
for (auto dim : instruction->dimensions()) {
attrs["dims"].mutable_list()->add_i(dim);
}
break;
case HloOpcode::kGetTupleElement:
attrs["index"].set_i(instruction->tuple_index());
break;
case HloOpcode::kRng:
attrs["dist"].set_s(
RandomDistribution_Name(instruction->random_distribution()));
break;
case HloOpcode::kConstant:
if (ShapeUtil::IsScalar(instruction->shape())) {
attrs["value"].set_s(
LiteralUtil::GetAsString(instruction->literal(), {}));
}
break;
case HloOpcode::kCustomCall:
attrs["custom_call_target"].set_s(instruction->custom_call_target());
break;
default:
break;
}
}
Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
if (!visited_instructions_.insert(instruction).second) {
// Skip instructions that have already been added.
return Status::OK();
}
NodeDef* node_def = graph_def_.add_node();
node_def->set_name(GetNodeNameForInstruction(instruction));
node_def->set_op(GetOpDefName(instruction));
SetNodeAttrs(instruction, node_def);
if (instruction->opcode() == HloOpcode::kFusion) {
for (auto& fused_instruction : instruction->fused_instructions()) {
TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get()));
}
}
// Add all edges including control edges.
for (unsigned i = 0; i < instruction->operands().size(); ++i) {
*node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i));
}
// Called computations are control dependencies.
for (const auto* called_computation : instruction->called_computations()) {
*node_def->add_input() = StrCat(
"^", GetNodeNameForInstruction(called_computation->root_instruction()));
}
return Status::OK();
}
} // namespace tools
} // namespace xla

View File

@ -0,0 +1,59 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph.h"
namespace xla {
namespace tools {
// This constructs a tensorflow graph for HLO computations.
class HloTfGraphBuilder {
public:
// Adds a computation to the graph.
Status AddComputation(const HloComputation& computation);
const tensorflow::GraphDef& GetGraphDef() const;
private:
// Gets the node name of an instruction. The node name is hierarchical. For
// example, if an instruction is fused, it will be put in a subgraph of the
// fusion instruction.
const string& GetNodeNameForInstruction(const HloInstruction* instruction);
void SetNodeAttrs(const HloInstruction* instruction,
tensorflow::NodeDef* node_def) const;
Status AddInstruction(const HloInstruction* instruction);
tensorflow::GraphDef graph_def_;
// This records instructions that have been visited.
std::unordered_set<const HloInstruction*> visited_instructions_;
// A cache that maps instruction to the node name.
std::unordered_map<const HloInstruction*, string> instruction_to_node_name_;
};
// Cleans the node name to make it a valid name in a tensorflow graph.
void CleanNodeName(string* name);
} // namespace tools
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_

View File

@ -0,0 +1,154 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace tools {
namespace {
using ::tensorflow::GraphDef;
class HloTfGraphBuilderTest : public HloTestBase {
protected:
HloTfGraphBuilderTest() {}
HloTfGraphBuilder generator_;
// Create a computation which takes a scalar and returns its negation.
std::unique_ptr<HloComputation> CreateNegateComputation() {
auto builder = HloComputation::Builder("Negate");
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32_, "param0"));
builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
return builder.Build();
}
// Creates a computation which calls map with the given computation.
std::unique_ptr<HloComputation> CreateMapComputation(
HloComputation* map_computation) {
auto builder = HloComputation::Builder("Map");
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32_, "param0"));
builder.AddInstruction(
HloInstruction::CreateMap(r0f32_, {param}, map_computation));
return builder.Build();
}
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
};
TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
auto builder = HloComputation::Builder("Concatenate");
Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
auto param_1 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0"));
auto param_2 = builder.AddInstruction(
HloInstruction::CreateParameter(1, shape, "param1"));
builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(F32, {2, 4}), {param_1, param_2}, 1));
TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
GraphDef graph_def = generator_.GetGraphDef();
EXPECT_EQ(graph_def.node_size(), 3);
const auto &node = graph_def.node(2);
EXPECT_EQ(node.name(), "Concatenate/concatenate");
// Check dimensions.
auto dims_value = node.attr().find("dims");
CHECK(dims_value != node.attr().end());
EXPECT_EQ(dims_value->second.list().i_size(), 1);
EXPECT_EQ(dims_value->second.list().i(0), 1);
// Check shapes.
auto shape_value = node.attr().find("_output_shapes");
CHECK(shape_value != node.attr().end());
EXPECT_EQ(shape_value->second.list().shape_size(), 1);
EXPECT_EQ(shape_value->second.list().shape(0).dim_size(), 2);
EXPECT_EQ(shape_value->second.list().shape(0).dim(0).size(), 2);
EXPECT_EQ(shape_value->second.list().shape(0).dim(1).size(), 4);
}
TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
auto builder = HloComputation::Builder("Const");
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
GraphDef graph_def = generator_.GetGraphDef();
EXPECT_EQ(graph_def.node_size(), 1);
const auto &node = graph_def.node(0);
auto value = node.attr().find("value");
CHECK(value != node.attr().end());
EXPECT_EQ(value->second.s(), "123");
auto type = node.attr().find("type");
CHECK(type != node.attr().end());
EXPECT_EQ(type->second.s(), "S32");
}
TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
auto negate_computation = CreateNegateComputation();
TF_CHECK_OK(generator_.AddComputation(*negate_computation));
GraphDef graph_def = generator_.GetGraphDef();
EXPECT_EQ(graph_def.node_size(), 2);
EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0");
EXPECT_EQ(graph_def.node(0).op(), "HloParameter");
EXPECT_EQ(graph_def.node(1).name(), "Negate/negate");
EXPECT_EQ(graph_def.node(1).op(), "HloNegate");
EXPECT_EQ(graph_def.node(1).input_size(), 1);
EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0");
}
TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) {
auto builder = HloComputation::Builder("GE");
auto param_1 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32_, "param0"));
auto param_2 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r0f32_, "param1"));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
GraphDef graph_def = generator_.GetGraphDef();
EXPECT_EQ(graph_def.node_size(), 3);
EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
EXPECT_EQ(graph_def.node(2).input_size(), 2);
EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to");
EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
}
TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
// Create computations with a diamond-shaped callgraph.
auto negate_computation = CreateNegateComputation();
auto map1_computation = CreateMapComputation(negate_computation.get());
auto map2_computation = CreateMapComputation(negate_computation.get());
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32_, "param0"));
auto map1 = builder.AddInstruction(
HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get()));
auto map2 = builder.AddInstruction(
HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get()));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
auto computation = builder.Build();
TF_CHECK_OK(generator_.AddComputation(*computation));
EXPECT_GT(generator_.GetGraphDef().node_size(), 0);
}
} // namespace
} // namespace tools
} // namespace xla

View File

@ -42,12 +42,10 @@ class StochasticTensorTest(test.TestCase):
sigma2 = constant_op.constant([0.1, 0.2, 0.3])
prior_default = st.StochasticTensor(
distributions.Normal(
loc=mu, scale=sigma))
distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior_default.value_type, st.SampleValue))
prior_0 = st.StochasticTensor(
distributions.Normal(
loc=mu, scale=sigma),
distributions.Normal(loc=mu, scale=sigma),
dist_value_type=st.SampleValue())
self.assertTrue(isinstance(prior_0.value_type, st.SampleValue))
@ -55,8 +53,7 @@ class StochasticTensorTest(test.TestCase):
prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior.value_type, st.SampleValue))
likelihood = st.StochasticTensor(
distributions.Normal(
loc=prior, scale=sigma2))
distributions.Normal(loc=prior, scale=sigma2))
self.assertTrue(isinstance(likelihood.value_type, st.SampleValue))
coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
@ -102,8 +99,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue()):
prior_single = st.StochasticTensor(
distributions.Normal(
loc=mu, scale=sigma))
distributions.Normal(loc=mu, scale=sigma))
prior_single_value = prior_single.value()
self.assertEqual(prior_single_value.get_shape(), (2, 3))
@ -113,8 +109,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue(1)):
prior_single = st.StochasticTensor(
distributions.Normal(
loc=mu, scale=sigma))
distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
prior_single_value = prior_single.value()
@ -125,8 +120,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue(2)):
prior_double = st.StochasticTensor(
distributions.Normal(
loc=mu, scale=sigma))
distributions.Normal(loc=mu, scale=sigma))
prior_double_value = prior_double.value()
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
@ -163,8 +157,7 @@ class StochasticTensorTest(test.TestCase):
# With passed-in loss_fn.
dt = st.StochasticTensor(
distributions.Normal(
loc=mu, scale=sigma),
distributions.Normal(loc=mu, scale=sigma),
dist_value_type=st.MeanValue(stop_gradient=True),
loss_fn=sge.get_score_function_with_constant_baseline(
baseline=constant_op.constant(8.0)))
@ -199,8 +192,7 @@ class ObservedStochasticTensorTest(test.TestCase):
sigma = constant_op.constant([1.1, 1.2, 1.3])
obs = array_ops.zeros((2, 3))
z = st.ObservedStochasticTensor(
distributions.Normal(
loc=mu, scale=sigma), value=obs)
distributions.Normal(loc=mu, scale=sigma), value=obs)
[obs_val, z_val] = sess.run([obs, z.value()])
self.assertAllEqual(obs_val, z_val)
@ -212,15 +204,13 @@ class ObservedStochasticTensorTest(test.TestCase):
sigma = array_ops.placeholder(dtypes.float32)
obs = array_ops.placeholder(dtypes.float32)
z = st.ObservedStochasticTensor(
distributions.Normal(
loc=mu, scale=sigma), value=obs)
distributions.Normal(loc=mu, scale=sigma), value=obs)
mu2 = array_ops.placeholder(dtypes.float32, shape=[None])
sigma2 = array_ops.placeholder(dtypes.float32, shape=[None])
obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
z2 = st.ObservedStochasticTensor(
distributions.Normal(
loc=mu2, scale=sigma2), value=obs2)
distributions.Normal(loc=mu2, scale=sigma2), value=obs2)
coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
self.assertEqual(coll, [z, z2])
@ -231,22 +221,18 @@ class ObservedStochasticTensorTest(test.TestCase):
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(
loc=mu, scale=sigma),
distributions.Normal(loc=mu, scale=sigma),
value=array_ops.zeros((3,)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(
loc=mu, scale=sigma),
distributions.Normal(loc=mu, scale=sigma),
value=array_ops.zeros((3, 1)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(
loc=mu, scale=sigma),
value=array_ops.zeros(
(1, 2), dtype=dtypes.int32))
distributions.Normal(loc=mu, scale=sigma),
value=array_ops.zeros((1, 2), dtype=dtypes.int32))
if __name__ == "__main__":

View File

@ -135,8 +135,9 @@ from tensorflow.contrib.distributions.python.ops.wishart import *
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['ConditionalDistribution',
'ConditionalTransformedDistribution',
'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED']
_allowed_symbols = [
'ConditionalDistribution', 'ConditionalTransformedDistribution',
'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED'
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -488,9 +488,7 @@ class AffineBijectorTest(test.TestCase):
shift=mu,
scale_identity_multiplier=2.,
scale_perturb_diag=[2., 1],
scale_perturb_factor=[[2., 0],
[0., 0],
[0, 1]])
scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
bijector_ref = affine_lib.Affine(shift=mu, scale_diag=[10., 2, 3])
self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
@ -526,9 +524,7 @@ class AffineBijectorTest(test.TestCase):
shift=mu,
scale_diag=[2., 3, 4],
scale_perturb_diag=[2., 1],
scale_perturb_factor=[[2., 0],
[0., 0],
[0, 1]])
scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
bijector_ref = affine_lib.Affine(shift=mu, scale_diag=[10., 3, 5])
self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
@ -561,17 +557,11 @@ class AffineBijectorTest(test.TestCase):
# Corresponds to scale = [[10, 0, 0], [1, 3, 0], [2, 3, 5]]
bijector = affine_lib.Affine(
shift=mu,
scale_tril=[[2., 0, 0],
[1, 3, 0],
[2, 3, 4]],
scale_tril=[[2., 0, 0], [1, 3, 0], [2, 3, 4]],
scale_perturb_diag=[2., 1],
scale_perturb_factor=[[2., 0],
[0., 0],
[0, 1]])
scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
bijector_ref = affine_lib.Affine(
shift=mu, scale_tril=[[10., 0, 0],
[1, 3, 0],
[2, 3, 5]])
shift=mu, scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]])
self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.

View File

@ -70,7 +70,8 @@ class ChainBijectorTest(test.TestCase):
softmax_centered_lib.SoftmaxCentered(
event_ndims=1, validate_args=True),
softmax_centered_lib.SoftmaxCentered(
event_ndims=0, validate_args=True)])
event_ndims=0, validate_args=True)
])
x = tensor_shape.TensorShape([])
y = tensor_shape.TensorShape([2 + 1])
self.assertAllEqual(y, bijector.forward_event_shape(x))

View File

@ -36,17 +36,19 @@ class SigmoidBijectorTest(test.TestCase):
y = special.expit(x)
ildj = -np.log(y) - np.log1p(-y)
self.assertAllClose(
y, sigmoid.Sigmoid().forward(x).eval(),
atol=0., rtol=1e-2)
y, sigmoid.Sigmoid().forward(x).eval(), atol=0., rtol=1e-2)
self.assertAllClose(
x, sigmoid.Sigmoid().inverse(y).eval(),
atol=0., rtol=1e-4)
x, sigmoid.Sigmoid().inverse(y).eval(), atol=0., rtol=1e-4)
self.assertAllClose(
ildj, sigmoid.Sigmoid().inverse_log_det_jacobian(y).eval(),
atol=0., rtol=1e-6)
ildj,
sigmoid.Sigmoid().inverse_log_det_jacobian(y).eval(),
atol=0.,
rtol=1e-6)
self.assertAllClose(
-ildj, sigmoid.Sigmoid().forward_log_det_jacobian(x).eval(),
atol=0., rtol=1e-4)
-ildj,
sigmoid.Sigmoid().forward_log_det_jacobian(x).eval(),
atol=0.,
rtol=1e-4)
def testScalarCongruency(self):
with self.test_session():

View File

@ -171,11 +171,12 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution):
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits=logits, probs=probs, validate_args=validate_args)
super(RelaxedBernoulli, self).__init__(
distribution=logistic.Logistic(self._logits / self._temperature,
1. / self._temperature,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name + "/Logistic"),
distribution=logistic.Logistic(
self._logits / self._temperature,
1. / self._temperature,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name + "/Logistic"),
bijector=sigmoid_lib.Sigmoid(validate_args=validate_args),
validate_args=validate_args,
name=name)

View File

@ -3614,7 +3614,7 @@ _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
try:
_config = json.load(open(_config_path))
except json.decoder.JSONDecodeError:
except ValueError:
_config = {}
_floatx = _config.get('floatx', floatx())
assert _floatx in {'float16', 'float32', 'float64'}

View File

@ -379,7 +379,10 @@ def batch_norm(inputs,
fused=False,
data_format=DATA_FORMAT_NHWC,
zero_debias_moving_mean=False,
scope=None):
scope=None,
renorm=False,
renorm_clipping=None,
renorm_decay=0.99):
"""Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
"Batch Normalization: Accelerating Deep Network Training by Reducing
@ -446,6 +449,19 @@ def batch_norm(inputs,
zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
scope: Optional scope for `variable_scope`.
renorm: Whether to use Batch Renormalization
(https://arxiv.org/abs/1702.03275). This adds extra variables during
training. The inference is the same for either value of this parameter.
renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
scalar `Tensors` used to clip the renorm correction. The correction
`(r, d)` is used as `corrected_value = normalized_value * r + d`, with
`r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
dmax are set to inf, 0, inf, respectively.
renorm_decay: Momentum used to update the moving means and standard
deviations with renorm. Unlike `momentum`, this affects training
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `decay` is still applied
to get the means and variances for inference.
Returns:
A `Tensor` representing the output of the operation.
@ -464,6 +480,8 @@ def batch_norm(inputs,
if param_regularizers is not None:
raise ValueError('Regularizers are not currently '
'supported for fused batch norm.')
if renorm:
raise ValueError('Renorm is not supported for fused batch norm.')
return _fused_batch_norm(
inputs,
decay=decay,
@ -524,6 +542,9 @@ def batch_norm(inputs,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
trainable=trainable,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_decay,
name=sc.name,
_scope=sc,
_reuse=reuse)
@ -551,6 +572,9 @@ def batch_norm(inputs,
# Custom updates collections are not supported because the update logic
# is different in this case, in particular w.r.t. "forced updates" and
# update op reuse.
if renorm:
raise ValueError('renorm is not supported with batch_weights, '
'updates_collections or zero_debias_moving_mean')
inputs_shape = inputs.get_shape()
inputs_rank = inputs_shape.ndims
if inputs_rank is None:
@ -1241,6 +1265,13 @@ def flatten(inputs,
def _sparse_inner_flatten(inputs, new_rank):
"""Helper function for `inner_flatten`."""
inputs_rank = inputs.dense_shape.get_shape().as_list()[0]
if inputs_rank < new_rank:
raise ValueError(
'Inputs has rank less than new_rank. {} must have rank at least'
' {}. Received rank {}, shape {}'.format(inputs, new_rank, inputs_rank,
inputs.get_shape()))
outer_dimensions = inputs.dense_shape[:new_rank - 1]
inner_dimensions = inputs.dense_shape[new_rank - 1:]
new_shape = array_ops.concat((outer_dimensions,

View File

@ -1465,6 +1465,30 @@ class PartialFlattenTest(test.TestCase):
flattened5 = _layers._inner_flatten(inputs, 5)
self.assertEqual([2, None, 4, None, 30], flattened5.get_shape().as_list())
def testDenseFlattenRankAssertion(self):
"""Test `_inner_flatten` rank assertion for dense tensors."""
shape = [2, 3]
new_rank = 3
inputs = array_ops.placeholder(dtypes.int32)
inputs.set_shape(shape)
with self.assertRaisesRegexp(ValueError,
'inputs has rank less than new_rank'):
_layers._inner_flatten(inputs, new_rank)
def testSparseFlattenRankAssertion(self):
"""Test `_inner_flatten` rank assertion for sparse tensors."""
shape = [2, 3]
new_rank = 3
np.random.seed(10301)
random_ = np.random.rand(*shape)
indices, values, _ = _sparsify(random_)
inputs = sparse_tensor.SparseTensor(indices, values, shape)
with self.assertRaisesRegexp(ValueError,
'Inputs has rank less than new_rank'):
_layers._inner_flatten(inputs, new_rank)
class FCTest(test.TestCase):

View File

@ -27,7 +27,8 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.python.framework import dtypes
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
def _read32(bytestream):

View File

@ -362,6 +362,11 @@ class BaseEstimator(
self._config = config
logging.info('Using config: %s', str(vars(self._config)))
if self._config.session_config is None:
self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
else:
self._session_config = self._config.session_config
# Model directory.
if (model_dir is not None) and (self._config.model_dir is not None):
if model_dir != self._config.model_dir:
@ -829,7 +834,7 @@ class BaseEstimator(
eval_ops=update_op,
final_ops=eval_dict,
hooks=hooks,
config=config_pb2.ConfigProto(allow_soft_placement=True))
config=self._session_config)
current_global_step = eval_results[global_step_key]
_write_dict_to_summary(eval_dir, eval_results, current_global_step)
@ -864,7 +869,7 @@ class BaseEstimator(
session_creator=monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint_path,
scaffold=infer_ops.scaffold,
config=config_pb2.ConfigProto(allow_soft_placement=True)))
config=self._session_config))
if not as_iterable:
with mon_sess:
if not mon_sess.should_stop():
@ -976,7 +981,7 @@ class BaseEstimator(
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
config=config_pb2.ConfigProto(allow_soft_placement=True)
config=self._session_config
) as mon_sess:
loss = None
while not mon_sess.should_stop():

View File

@ -53,12 +53,17 @@ class ModeKeys(object):
EVAL = 'eval'
INFER = 'infer'
@classmethod
def validate(cls, key):
if key not in (cls.TRAIN, cls.EVAL, cls.INFER):
raise ValueError('Invalid mode %s.' % key)
class ModelFnOps(
collections.namedtuple('ModelFnOps', [
'predictions', 'loss', 'train_op', 'eval_metric_ops',
'output_alternatives', 'training_chief_hooks', 'training_hooks',
'scaffold'
'scaffold', 'mode'
])):
"""Ops returned from a model_fn."""
@ -119,6 +124,8 @@ class ModelFnOps(
Raises:
ValueError: If validation fails.
"""
ModeKeys.validate(mode)
# Assert all ops are from the same graph.
get_graph_from_inputs((predictions, loss, train_op))
@ -183,14 +190,13 @@ class ModelFnOps(
output_alternatives=output_alternatives,
training_chief_hooks=training_chief_hooks,
training_hooks=training_hooks,
scaffold=scaffold)
scaffold=scaffold,
mode=mode)
def estimator_spec(self, mode, default_serving_output_alternative_key=None):
def estimator_spec(self, default_serving_output_alternative_key=None):
"""Creates an equivalent `EstimatorSpec`.
Args:
mode: One of `ModeKeys`. Specifies if this training, evaluation or
prediction.
default_serving_output_alternative_key: Required for multiple heads. If
you have multiple entries in `output_alternatives` dict (comparable to
multiple heads), `EstimatorSpec` requires a default head that will be
@ -265,7 +271,7 @@ class ModelFnOps(
return result
return core_model_fn_lib.EstimatorSpec(
mode=mode,
mode=self.mode,
predictions=self.predictions,
loss=self.loss,
train_op=self.train_op,

View File

@ -80,18 +80,20 @@ class ModelFnopsTest(test.TestCase):
def testEstimatorSpec_except_export(self):
predictions = self.create_predictions()
model_fn_ops = self.create_model_fn_ops(predictions, None)
model_fn_ops = self.create_model_fn_ops(
predictions, None, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
def testEstimatorSpec_export_regression_with_scores(self):
predictions = self.create_predictions()
output_alternatives = {"regression_head": (
constants.ProblemType.LINEAR_REGRESSION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -108,9 +110,10 @@ class ModelFnopsTest(test.TestCase):
output_alternatives = {"regression_head": (
constants.ProblemType.LINEAR_REGRESSION,
output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -124,9 +127,10 @@ class ModelFnopsTest(test.TestCase):
predictions = self.create_predictions()
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -145,9 +149,10 @@ class ModelFnopsTest(test.TestCase):
del output_alternatives_predictions["scores"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -167,9 +172,10 @@ class ModelFnopsTest(test.TestCase):
del output_alternatives_predictions["probabilities"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -187,9 +193,10 @@ class ModelFnopsTest(test.TestCase):
del output_alternatives_predictions["classes"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -208,9 +215,10 @@ class ModelFnopsTest(test.TestCase):
[1, 2, 3])
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -226,9 +234,10 @@ class ModelFnopsTest(test.TestCase):
predictions = self.create_predictions()
output_alternatives = {"logistic_head": (
constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -245,9 +254,10 @@ class ModelFnopsTest(test.TestCase):
output_alternatives = {"unspecified_head": (
constants.ProblemType.UNSPECIFIED, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec()
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():
@ -263,10 +273,10 @@ class ModelFnopsTest(test.TestCase):
constants.ProblemType.LINEAR_REGRESSION, predictions),
"classification_head": (
constants.ProblemType.CLASSIFICATION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
model_fn_ops = self.create_model_fn_ops(
predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER,
"regression_head")
estimator_spec = model_fn_ops.estimator_spec("regression_head")
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)
with session.Session():

View File

@ -214,7 +214,8 @@ class RunConfig(ClusterConfig):
keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000,
evaluation_master='',
model_dir=None):
model_dir=None,
session_config=None):
"""Constructor.
Note that the superclass `ClusterConfig` may set properties like
@ -246,6 +247,9 @@ class RunConfig(ClusterConfig):
evaluation_master: the master on which to perform evaluation.
model_dir: directory where model parameters, graph etc are saved. If
`None`, see `Estimator` about where the model will be saved.
session_config: a ConfigProto used to set session parameters, or None.
Note - using this argument, it is easy to provide settings which break
otherwise perfectly good models. Use with care.
"""
super(RunConfig, self).__init__(
master=master, evaluation_master=evaluation_master)
@ -261,6 +265,7 @@ class RunConfig(ClusterConfig):
self._tf_random_seed = tf_random_seed
self._save_summary_steps = save_summary_steps
self._save_checkpoints_secs = save_checkpoints_secs
self._session_config = session_config
if save_checkpoints_secs == RunConfig._USE_DEFAULT:
if save_checkpoints_steps is None:
self._save_checkpoints_secs = 600
@ -345,6 +350,10 @@ class RunConfig(ClusterConfig):
def save_checkpoints_steps(self):
return self._save_checkpoints_steps
@property
def session_config(self):
return self._session_config
@property
def keep_checkpoint_max(self):
return self._keep_checkpoint_max

View File

@ -118,7 +118,8 @@ class Experiment(object):
occur if no new snapshot is available, hence, this is the minimum.
delay_workers_by_global_step: if `True` delays training workers
based on global step instead of time.
export_strategies: A list of `ExportStrategy`s, or a single one, or None.
export_strategies: Iterable of `ExportStrategy`s, or a single one, or
`None`.
train_steps_per_iteration: (applies only to continuous_train_and_eval).
Perform this many (integer) number of train steps for each
training-evaluation iteration. With a small value, the model will be
@ -184,16 +185,19 @@ class Experiment(object):
def eval_steps(self):
return self._eval_steps
def _set_export_strategies(self, value):
if value is None:
self._export_strategies = []
elif isinstance(value, list):
self._export_strategies = value[:]
elif isinstance(value, export_strategy.ExportStrategy):
self._export_strategies = [value]
else:
raise ValueError("`export_strategies` must be an ExportStrategy, "
"a list of ExportStrategies, or None.")
def _set_export_strategies(self, values): # pylint: disable=missing-docstring
export_strategies = []
if values:
if isinstance(values, export_strategy.ExportStrategy):
export_strategies.append(values)
else:
for value in values:
if not isinstance(value, export_strategy.ExportStrategy):
raise ValueError("`export_strategies` must be an ExportStrategy,"
" an iterable of ExportStrategy, or `None`,"
" found %s." % value)
export_strategies.append(value)
self._export_strategies = tuple(export_strategies)
def extend_train_hooks(self, additional_hooks):
"""Extends the hooks for training."""

View File

@ -484,6 +484,25 @@ class ExperimentTest(test.TestCase):
self.assertAllEqual([noop_hook, another_noop_hook], ex._train_monitors)
self.assertAllEqual([noop_hook], input_hooks)
def test_invalid_export_strategies(self):
for est in self._estimators_for_tests():
with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
experiment.Experiment(
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
train_steps=100,
eval_steps=100,
export_strategies='not_an_export_strategy')
with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
experiment.Experiment(
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
train_steps=100,
eval_steps=100,
export_strategies=['not_an_export_srategy'])
def test_export_strategies_reset(self):
for est in self._estimators_for_tests():
eval_metrics = 'eval_metrics' if not isinstance(
@ -498,7 +517,7 @@ class ExperimentTest(test.TestCase):
eval_metrics=eval_metrics,
train_steps=100,
eval_steps=100,
export_strategies=[export_strategy_1])
export_strategies=(export_strategy_1,))
ex.train_and_evaluate()
self.assertEqual(1, est.export_count)
@ -728,7 +747,7 @@ class ExperimentTest(test.TestCase):
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
export_strategies=[exp_strategy])
export_strategies=(exp_strategy,))
ex.test()
self.assertEqual(1, est.fit_count)
self.assertEqual(1, est.eval_count)

View File

@ -131,4 +131,5 @@ def generator_input_fn(x,
target = features.pop(target_key[0])
return features, target
return features
return _generator_input_fn

View File

@ -35,7 +35,7 @@ from tensorflow.python.training import queue_runner_impl
class GeneratorIoTest(test.TestCase):
def testGeneratorInputFn(self):
def generator():

View File

@ -359,8 +359,9 @@ def _read_keyed_batch_examples_helper(file_pattern,
# Check input parameters are given and reasonable.
if (not queue_capacity) or (queue_capacity <= 0):
raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
if (batch_size is None) or ((not isinstance(batch_size, ops.Tensor)) and
(batch_size <= 0 or batch_size > queue_capacity)):
if (batch_size is None) or (
(not isinstance(batch_size, ops.Tensor)) and
(batch_size <= 0 or batch_size >= queue_capacity)):
raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
(batch_size, queue_capacity))
if (read_batch_size is None) or (

View File

@ -112,6 +112,18 @@ class GraphIOTest(test.TestCase):
queue_capacity=queue_capacity,
num_threads=num_threads,
name=name)
self.assertRaisesRegexp(
ValueError,
"Invalid batch_size",
graph_io.read_batch_examples,
_VALID_FILE_PATTERN,
default_batch_size,
io_ops.TFRecordReader,
False,
num_epochs=None,
queue_capacity=default_batch_size,
num_threads=num_threads,
name=name)
self.assertRaisesRegexp(
ValueError,
"Invalid queue_capacity",
@ -356,7 +368,7 @@ class GraphIOTest(test.TestCase):
]
filename = self._create_temp_file("".join(json_lines))
batch_size = 10000
queue_capacity = 10000
queue_capacity = 100000
name = "my_large_batch"
features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}

View File

@ -29,10 +29,6 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
def tearDownModule():
gfile.DeleteRecursively(test.get_temp_dir())
class GcTest(test_util.TensorFlowTestCase):
def testLargestExportVersions(self):

View File

@ -30,7 +30,7 @@ cuda_py_tests(
cuda_py_tests(
name = "linear_operator_addition_test",
size = "medium",
size = "small",
srcs = ["python/kernel_tests/linear_operator_addition_test.py"],
additional_deps = [
":linalg_py",
@ -43,7 +43,6 @@ cuda_py_tests(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
shard_count = 5,
)
cuda_py_tests(
@ -61,7 +60,6 @@ cuda_py_tests(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
shard_count = 5,
)
cuda_py_tests(
@ -79,7 +77,6 @@ cuda_py_tests(
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
shard_count = 5,
)
cuda_py_tests(
@ -96,7 +93,6 @@ cuda_py_tests(
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
shard_count = 5,
)
cuda_py_tests(
@ -112,7 +108,6 @@ cuda_py_tests(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
shard_count = 5,
)
cuda_py_tests(
@ -128,7 +123,6 @@ cuda_py_tests(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
shard_count = 5,
)
cuda_py_tests(
@ -144,12 +138,11 @@ cuda_py_tests(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
shard_count = 5,
)
cuda_py_tests(
name = "linear_operator_util_test",
size = "small",
size = "medium",
srcs = ["python/kernel_tests/linear_operator_util_test.py"],
additional_deps = [
":linalg_py",
@ -160,7 +153,6 @@ cuda_py_tests(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
shard_count = 5,
)
py_library(

View File

@ -229,6 +229,29 @@ class MatmulWithBroadcastTest(test.TestCase):
self.assertAllEqual(expected, result)
class MatrixAdjointTest(test.TestCase):
def testNonBatchMatrix(self):
a = [[1, 2, 3j], [4, 5, -6j]] # Shape (2, 3)
expected = [[1, 4], [2, 5], [-3j, 6j]] # Shape (3, 2)
with self.test_session():
a_adj = linear_operator_util.matrix_adjoint(a)
self.assertEqual((3, 2), a_adj.get_shape())
self.assertAllClose(expected, a_adj.eval())
def testBatchMatrix(self):
matrix_0 = [[1j, 2, 3], [4, 5, 6]]
matrix_0_a = [[-1j, 4], [2, 5], [3, 6]]
matrix_1 = [[11, 22, 33], [44, 55, 66j]]
matrix_1_a = [[11, 44], [22, 55], [33, -66j]]
batch_matrix = [matrix_0, matrix_1] # Shape (2, 2, 3)
expected_adj = [matrix_0_a, matrix_1_a] # Shape (2, 3, 2)
with self.test_session():
matrix_adj = linear_operator_util.matrix_adjoint(batch_matrix)
self.assertEqual((2, 3, 2), matrix_adj.get_shape())
self.assertAllEqual(expected_adj, matrix_adj.eval())
class DomainDimensionStubOperator(object):
def __init__(self, domain_dimension):

View File

@ -289,6 +289,53 @@ def matmul_with_broadcast(a,
b_is_sparse=b_is_sparse)
def matrix_adjoint(a, name="matrix_adjoint"):
"""Transposes last two dimensions of tensor `a`, and takes complex conjugate.
If `a` is real valued, the result is equivalent to `matrix_transpose`.
For example:
```python
# Matrix with no batch dimension.
# 'x' is [[1 2 3j]
# [4 5 -6j]]
tf.matrix_adjoint(x) ==> [[1 4]
[2 5]
[-3j 6j]]
# Matrix with two batch dimensions.
# x.shape is [1, 2, 3, 4]
# tf.matrix_adjoint(x) is shape [1, 2, 4, 3]
```
Note that `tf.matmul` provides kwargs allowing for adjoint of arguments. This
is done with minimal cost, and is preferable to using this function. E.g.
```
# Good! Adjoint is taken at minimal additional cost.
tf.matmul(matrix, b, adjoint_b=True)
# Inefficient!
tf.matmul(matrix, tf.matrix_adjoint(b))
```
Args:
a: A `Tensor` with `rank >= 2`.
name: A name for the operation (optional).
Returns:
A batch matrix `Tensor` with same `dtype` as `a`.
Raises:
ValueError: If `a` is determined statically to have `rank < 2`.
"""
with ops.name_scope(name, values=[a]):
a = ops.convert_to_tensor(a, name="a")
a_transpose = array_ops.matrix_transpose(a)
return math_ops.conj(a_transpose)
def shape_tensor(shape, name=None):
"""Convert Tensor using default type, unless empty list or tuple."""
# Works just like random_ops._ShapeTensor.

View File

@ -0,0 +1,236 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
__all__ = [
"BeamSearchDecoderOutput",
"BeamSearchDecoderState",
"BeamSearchDecoder",
]
class BeamSearchDecoderOutput(
collections.namedtuple("BeamSearchDecoderOutput", ("rnn_output",))):
pass
class BeamSearchDecoderState(
collections.namedtuple("BeamSearchDecoderState",
("cell_state", "log_prob", "beam_ids"))):
pass
class BeamSearchDecoder(decoder.Decoder):
"""BeamSearch sampling decoder."""
def __init__(self, cell, embedding, start_tokens, end_token,
initial_state, beam_width, output_layer=None):
"""Initialize BeamSearchDecoder.
Args:
cell: An `RNNCell` instance.
embedding: A callable that takes a vector tensor of `ids` (argmax ids),
or the `params` argument for `embedding_lookup`.
start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
end_token: `int32` scalar, the token that marks end of decoding.
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
beam_width: Python integer, the number of beams
output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
Raises:
TypeError: if `cell` is not an instance of `RNNCell`,
or `output_layer` is not an instance of `tf.layers.Layer`.
ValueError: If `start_tokens` is not a vector or
`end_token` is not a scalar.
"""
if not isinstance(cell, core_rnn_cell.RNNCell):
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
if (output_layer is not None
and not isinstance(output_layer, layers_base._Layer)): # pylint: disable=protected-access
raise TypeError(
"output_layer must be a Layer, received: %s" % type(output_layer))
self._cell = cell
self._initial_cell_state = initial_state
self._output_layer = output_layer
if callable(embedding):
self._embedding_fn = embedding
else:
self._embedding_fn = (
lambda ids: embedding_ops.embedding_lookup(embedding, ids))
self._start_tokens = ops.convert_to_tensor(
start_tokens, dtype=dtypes.int32, name="start_tokens")
self._end_token = ops.convert_to_tensor(
end_token, dtype=dtypes.int32, name="end_token")
if self._start_tokens.get_shape().ndims != 1:
raise ValueError("start_tokens must be a vector")
self._batch_size = array_ops.size(start_tokens)
self._beam_width = beam_width
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._start_inputs = self._embedding_fn(self._start_tokens)
@property
def batch_size(self):
return self._batch_size
def _rnn_output_size(self):
size = self._cell.output_size
if self._output_layer is None:
return size
else:
# To use layer's compute_output_shape, we need to convert the
# RNNCell's output_size entries into shapes with an unknown
# batch size. We then pass this through the layer's
# compute_output_shape and read off all but the first (batch)
# dimensions to get the output size of the rnn with the layer
# applied to the top.
output_shape_with_unknown_batch = nest.map_structure(
lambda s: tensor_shape.TensorShape([None]).concatenate(s),
size)
layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access
output_shape_with_unknown_batch)
return nest.map_structure(lambda s: s[1:], layer_output_shape)
@property
def output_size(self):
# Return the cell output and the id
prepend_beam_width = (
lambda s: tensor_shape.TensorShape([self._beam_width]).concatenate(s))
return BeamSearchDecoderOutput(
rnn_output=nest.map_structure(
prepend_beam_width, self._rnn_output_size()))
@property
def output_dtype(self):
# Assume the dtype of the cell is the output_size structure
# containing the input_state's first component's dtype.
# Return that structure and int32 (the id)
dtype = nest.flatten(self._initial_cell_state)[0].dtype
return BeamSearchDecoderOutput(
rnn_output=nest.map_structure(lambda _: dtype, self._rnn_output_size()))
def initialize(self, name=None):
"""Initialize the decoder.
Args:
name: Name scope for any created operations.
Returns:
`(finished, first_inputs, initial_state)`.
"""
finished, first_inputs = self._finished, self._first_inputs
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=array_ops.zeros(
[self.batch_size, self.beam_width],
dtype=nest.flatten(self._initial_cell_state)[0].dtype),
beam_ids=tensor_array_ops.TensorArray(
size=0, dynamic_size=True, dtype=dtypes.int32,
clear_after_read=False))
return (finished, first_inputs, initial_state)
def _merge_batch_beams(self, t):
t_static_shape = t.shape
t_shape = array_ops.shape(t)
static_batch_size = tensor_util.constant_value(self._batch_size)
batch_size_beam_width = (
None if static_batch_size is None
else static_batch_size * self._beam_width)
reshaped_t = array_ops.reshape(
t, array_ops.concat(
([self._batch_size * self._beam_width], t_shape[2:]), 0))
reshaped_t.set_shape(
(tensor_shape.TensorShape([batch_size_beam_width])
.concatenate(t_static_shape[2:])))
return reshaped_t
def _split_batch_beams(self, t):
t_static_shape = t.shape
t_shape = array_ops.shape(t)
reshaped_t = array_ops.reshape(
t, array_ops.concat(
([self._batch_size, self._beam_width], t_shape[1:]), 0))
static_batch_size = tensor_util.constant_value(self._batch_size)
reshaped_t.set_shape(
(tensor_shape.TensorShape([static_batch_size, self._beam_width])
.concatenate(t_static_shape[1:])))
return reshaped_t
def step(self, time, inputs, state, name=None):
"""Perform a decoding step.
Args:
time: scalar `int32` tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
name: Name scope for any created operations.
Returns:
`(outputs, next_state, next_inputs, finished)`.
"""
with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
cell_state = state.cell_state
inputs = nest.map_structure(self._merge_batch_beams, inputs)
cell_state = nest.map_structure(self._merge_batch_beams, cell_state)
cell_outputs, next_cell_state = self._cell(inputs, cell_state)
cell_outputs = nest.map_structure(self._split_batch_beams, cell_outputs)
next_cell_state = nest.map_structure(self._split_batch_beams,
next_cell_state)
if self._output_layer is not None:
cell_outputs = self._output_layer(cell_outputs)
# TODO(cinjon): Calculate next_log_probs, next_beam_ids,
# finished, next_inputs, final_cell_state via beam search
# via self._embedding
# ....
next_beam_ids, next_log_probs, final_cell_state, next_inputs, finished = (
None, None, None, None, None)
beam_ids = state.beam_ids.write(time, next_beam_ids)
outputs = BeamSearchDecoderOutput(cell_outputs)
next_state = BeamSearchDecoderState(
log_probs=next_log_probs,
beam_ids=beam_ids,
cell_state=final_cell_state)
return (outputs, next_state, next_inputs, finished)

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
@ -38,34 +39,7 @@ from tensorflow.python.util import nest
__all__ = ["Decoder", "dynamic_decode"]
def _transpose_batch_time(x):
"""Transpose the batch and time dimensions of a Tensor.
Retains as much of the static shape information as possible.
Args:
x: A tensor of rank 2 or higher.
Returns:
x transposed along the first two dimensions.
Raises:
ValueError: if `x` is rank 1 or lower.
"""
x_static_shape = x.get_shape()
if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
raise ValueError(
"Expected input tensor %s to have rank at least 2, but saw shape: %s" %
(x, x_static_shape))
x_rank = array_ops.rank(x)
x_t = array_ops.transpose(
x, array_ops.concat(
([1, 0], math_ops.range(2, x_rank)), axis=0))
x_t.set_shape(
tensor_shape.TensorShape([
x_static_shape[1].value, x_static_shape[0].value
]).concatenate(x_static_shape[2:]))
return x_t
_transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-access
@six.add_metaclass(abc.ABCMeta)

View File

@ -29,10 +29,6 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
def tearDownModule():
gfile.DeleteRecursively(test.get_temp_dir())
class GcTest(test_util.TensorFlowTestCase):
def testLargestExportVersions(self):

View File

@ -272,6 +272,7 @@ cc_library(
"lib/monitoring/sampler.h",
"lib/random/distribution_sampler.h",
"lib/random/philox_random.h",
"lib/random/random_distributions.h",
"lib/random/simple_philox.h",
"lib/strings/numbers.h",
"lib/strings/str_util.h",
@ -383,6 +384,7 @@ tf_cuda_library(
"util/bcast.h",
"util/cuda_kernel_helper.h",
"util/device_name_utils.h",
"util/env_var.h",
"util/events_writer.h",
"util/example_proto_fast_parsing.h",
"util/example_proto_helper.h",
@ -1535,7 +1537,10 @@ cc_library(
tf_cuda_library(
name = "direct_session_internal",
srcs = ["common_runtime/direct_session.cc"],
hdrs = ["common_runtime/direct_session.h"],
hdrs = [
"common_runtime/direct_session.h",
"util/env_var.h",
],
copts = tf_copts(),
cuda_deps = [
":gpu_tracer",

View File

@ -57,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/env_var.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
@ -242,6 +243,13 @@ DirectSession::DirectSession(const SessionOptions& options,
thread_pools_.push_back(GlobalThreadPool(options));
owns_thread_pools_ = false;
}
// The default value of sync_on_finish will be flipped soon and this
// environment variable will be removed as well.
Status status =
ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
// NOTE(mrry): We do not need to use a unique string for the session
// handle, because DirectSession owns its devices. This may change
// in future versions.
@ -448,7 +456,7 @@ Status DirectSession::Run(const RunOptions& run_options,
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
args.sync_on_finish = true;
args.sync_on_finish = sync_on_finish_;
const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
@ -632,7 +640,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
args.sync_on_finish = true;
args.sync_on_finish = sync_on_finish_;
if (options_.config.graph_options().build_cost_model()) {
run_state->collector.reset(new StepStatsCollector(nullptr));

View File

@ -247,6 +247,8 @@ class DirectSession : public Session {
std::vector<thread::ThreadPool*> thread_pools_;
bool owns_thread_pools_ = false;
// If true, blocks until device has finished all queued operations in a step.
bool sync_on_finish_ = true;
// Schedules 'c' for execution on pool.
void SchedClosure(thread::ThreadPool* pool, std::function<void()> c);

View File

@ -64,6 +64,10 @@ class ShapeRefiner {
return it->second.get();
}
// Getters and setters for graph_def_version_.
int32 graph_def_version() { return graph_def_version_; }
void set_graph_def_version(int32 version) { graph_def_version_ = version; }
private:
// Extracts the subgraph ending at 'node' that is statically
// computable and inserts into 'out_graph'. If statically computable,
@ -100,7 +104,7 @@ class ShapeRefiner {
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
const int graph_def_version_;
int32 graph_def_version_;
const OpRegistryInterface* const ops_registry_;
// The lifetime of the tensors are bound to the runner, so it should be the

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
@ -48,6 +49,13 @@ GraphMgr::GraphMgr(const WorkerEnv* worker_env,
RendezvousMgrInterface* rendezvous_mgr)
: worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) {
CHECK(rendezvous_mgr) << "Rendezvous mgr was null";
// The default value of sync_on_finish will be flipped soon and this
// environment variable will be removed as well.
Status status =
ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
}
GraphMgr::~GraphMgr() {
@ -486,7 +494,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
args.cancellation_manager = cancellation_manager;
args.stats_collector = collector;
args.step_container = step_container;
args.sync_on_finish = true;
args.sync_on_finish = sync_on_finish_;
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, handle);
}

View File

@ -137,6 +137,9 @@ class GraphMgr {
mutex mu_;
int64 next_id_ GUARDED_BY(mu_) = 0;
// If true, blocks until device has finished all queued operations in a step.
bool sync_on_finish_ = true;
// Table mapping graph handles to registered graphs.
//
// TODO(zhifengc): If the client does not call Deregister, we'll

View File

@ -873,6 +873,13 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
return Status::OK();
}
Status RandomShape(shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
c->set_output(0, out);
return Status::OK();
}
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.

View File

@ -199,6 +199,9 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
// Tested by ops/math_ops_test.cc.
Status BroadcastBinaryOpShapeFn(InferenceContext* c);
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,

View File

@ -125,6 +125,23 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start,
}
}
Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
if (!IsLegacyVector(shape.shape())) {
return errors::InvalidArgument(
"shape must be a vector of {int32,int64}, got shape ",
shape.shape().DebugString());
}
if (shape.dtype() == DataType::DT_INT32) {
auto vec = shape.flat<int32>();
return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
} else if (shape.dtype() == DataType::DT_INT64) {
auto vec = shape.flat<int64>();
return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
} else {
return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
}
}
void AsyncOpKernel::Compute(OpKernelContext* context) {
Notification n;
ComputeAsync(context, [&n]() { n.Notify(); });

View File

@ -151,6 +151,10 @@ class OpKernel {
return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
}
// Turn a shape Tensor into a TensorShape
// TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
Status MakeShape(const Tensor& shape, TensorShape* out) const;
private:
const NodeDef def_;
const DataTypeVector input_types_;

View File

@ -239,8 +239,11 @@ string InferenceContext::DebugString() const {
ProtoDebugString(node_def_));
}
Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
if (rank > kint32max) {
return errors::InvalidArgument("Rank cannot exceed kint32max");
}
const int32 existing = Rank(shape);
if (existing == rank) {
*out = shape;
@ -261,8 +264,11 @@ Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
existing);
}
Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank,
Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
if (rank > kint32max) {
return errors::InvalidArgument("Rank cannot exceed kint32max");
}
const int32 existing = Rank(shape);
if (existing >= rank) {
*out = shape;
@ -276,8 +282,11 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank,
" but is rank ", existing);
}
Status InferenceContext::WithRankAtMost(ShapeHandle shape, int32 rank,
Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
if (rank > kint32max) {
return errors::InvalidArgument("Rank cannot exceed kint32max");
}
const int32 existing = Rank(shape);
if (existing == kUnknownRank) {
return ReturnUnknownShape(out);
@ -470,12 +479,12 @@ Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
return ReturnCreatedShape(dims, out);
}
Status InferenceContext::ReplaceDim(ShapeHandle s, int dim_index_in,
Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
DimensionHandle new_dim, ShapeHandle* out) {
if (!RankKnown(s)) {
return ReturnUnknownShape(out);
}
int dim_index = dim_index_in;
int64 dim_index = dim_index_in;
if (dim_index < 0) {
dim_index = s->dims_.size() + dim_index;
}
@ -510,7 +519,8 @@ ShapeHandle InferenceContext::UnknownShape() {
return shape_manager_.UnknownShape();
}
ShapeHandle InferenceContext::UnknownShapeOfRank(int32 rank) {
ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
std::vector<DimensionHandle> dims(rank);
for (int32 i = 0; i < rank; ++i) {
dims[i] = UnknownDim();

View File

@ -194,7 +194,7 @@ class InferenceContext {
return s;
}
ShapeHandle input(int idx) const { return inputs_[idx]; }
ShapeHandle input(int64 idx) const { return inputs_[idx]; }
Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
int num_inputs() const { return inputs_.size(); }
@ -237,7 +237,7 @@ class InferenceContext {
// idx can be negative for an offset from end of dimensions.
// idx must be in the range [-1 * s.rank, s.rank).
DimensionHandle Dim(ShapeHandle s, int32 idx) {
DimensionHandle Dim(ShapeHandle s, int64 idx) {
if (s->rank_ == kUnknownRank) {
return UnknownDim();
}
@ -277,11 +277,11 @@ class InferenceContext {
// the shape with asserted rank in <*out>. Otherwise return an error.
//
// Note that <*out> may be set to <shape>.
Status WithRank(ShapeHandle shape, int32 rank,
Status WithRank(ShapeHandle shape, int64 rank,
ShapeHandle* out) TF_MUST_USE_RESULT;
Status WithRankAtLeast(ShapeHandle shape, int32 rank,
Status WithRankAtLeast(ShapeHandle shape, int64 rank,
ShapeHandle* out) TF_MUST_USE_RESULT;
Status WithRankAtMost(ShapeHandle shape, int32 rank,
Status WithRankAtMost(ShapeHandle shape, int64 rank,
ShapeHandle* out) TF_MUST_USE_RESULT;
// If <dim> has value <value>, or its value is unknown, returns OK and returns
@ -332,7 +332,7 @@ class InferenceContext {
// Returns in <out> the shape from replacing <s.dim[dim_index]> with
// <new_dim>.
Status ReplaceDim(ShapeHandle s, int dim_index, DimensionHandle new_dim,
Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
ShapeHandle* out) TF_MUST_USE_RESULT;
// Returns a new shape with the given dims. The returned value is owned by
@ -344,7 +344,7 @@ class InferenceContext {
ShapeHandle UnknownShape();
// Returns a shape with specified rank but unknown dims.
ShapeHandle UnknownShapeOfRank(int32 rank);
ShapeHandle UnknownShapeOfRank(int64 rank);
// Returns a new shape of zero dimensions.
ShapeHandle Scalar();

View File

@ -839,11 +839,6 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
Graph* g, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors) {
ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
if (refiner == nullptr) {
refiner = &default_refiner;
}
if (!opts.return_tensors.empty()) {
if (return_tensors == nullptr) {
return errors::InvalidArgument(
@ -857,6 +852,36 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
return_tensors->size(), ")");
}
}
ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
if (refiner == nullptr) {
refiner = &default_refiner;
} else {
// Log a warning if we are importing a GraphDef at an older
// producer version after already having added non-source/sink
// nodes to the graph in the past.
if (gdef.versions().producer() > 0 &&
gdef.versions().producer() < refiner->graph_def_version() &&
g->num_nodes() > 2) {
LOG(WARNING) << "Importing a graph with a lower producer version "
<< gdef.versions().producer()
<< " into an existing graph with producer version "
<< refiner->graph_def_version() << ". Shape inference will "
<< "have run different parts of the graph with different "
<< "producer versions.";
}
}
// Set the graph def version of the refiner as the min of the
// current value and the version from the graph we are about to
// import.
//
// Note: to match Run() semantics, we should re-run shape inference
// on the entire graph if the producer version has changed. For now
// we log the warning above.
refiner->set_graph_def_version(
std::min(refiner->graph_def_version(), gdef.versions().producer()));
return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors);
}

View File

@ -2271,5 +2271,176 @@ TEST_F(GraphConstructorTest, GraphDefVersionMergingDuringImport) {
EXPECT_EQ(3, graph_.versions().bad_consumers(2));
}
TEST_F(GraphConstructorTest, ImportGraphDefProvidedShapeRefinerVersions) {
ImportGraphDefOptions opts;
// A valid graph at producer version 20, but one
// that would not import if the graph_def_version were 21.
string gdef_ascii = strings::StrCat(R"EOF(
node {
name: "Sum/input"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 2
}
dim {
size: 1
}
}
tensor_content: "\001\000\000\000\002\000\000\000"
}
}
}
}
node {
name: "Sum/reduction_indices"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 2
}
dim {
size: 1
}
}
tensor_content: "\000\000\000\000\001\000\000\000"
}
}
}
}
node {
name: "Sum"
op: "Sum"
input: "Sum/input"
input: "Sum/reduction_indices"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "Tidx"
value {
type: DT_INT32
}
}
attr {
key: "keep_dims"
value {
b: false
}
}
}
versions {
producer: 20
})EOF");
// Create a shape refiner with the latest TF_GRAPH_DEF_VERSION.
// Importing the graphdef with an existing refiner should
// make the refiner inherit the graphdef version from the
// passed in graphdef since it has a lower producer.
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
ExpectOK(gdef_ascii, opts, &refiner);
// Add another node with a higher producer
gdef_ascii = strings::StrCat(R"EOF(
node {
name: "RandomConst"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 2
}
dim {
size: 1
}
}
tensor_content: "\001\000\000\000\002\000\000\000"
}
}
}
}
versions {
producer: 21
})EOF");
ExpectOK(gdef_ascii, opts, &refiner);
// Check that the refiner's graph def version is the lowest of
// the graph defs we have seen so far.
EXPECT_EQ(20, refiner.graph_def_version());
// Add another node with a lower producer
gdef_ascii = strings::StrCat(R"EOF(
node {
name: "RandomConst2"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 2
}
dim {
size: 1
}
}
tensor_content: "\001\000\000\000\002\000\000\000"
}
}
}
}
versions {
producer: 17
})EOF");
ExpectOK(gdef_ascii, opts, &refiner);
// Check that the refiner's graph def version is the lowest of
// the graph defs we have seen so far.
EXPECT_EQ(17, refiner.graph_def_version());
}
} // namespace
} // namespace tensorflow

View File

@ -29,6 +29,16 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "op_types",
srcs = ["op_types.cc"],
hdrs = ["op_types.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "utils",
srcs = ["utils.cc"],
@ -88,6 +98,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":grappler_item",
":op_types",
":utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib_internal",

View File

@ -53,6 +53,22 @@ int GetNumAvailableGPUs() {
return num_eligible_gpus;
}
int64 AvailableGPUMemory(int gpu_id) {
#if GOOGLE_CUDA
// Look up the device, to see its attributes.
perftools::gputools::Platform* gpu_platform = GPUMachineManager();
CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount());
perftools::gputools::StreamExecutor* se =
gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
int64 total_memory, available_memory;
CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
return available_memory;
#else
return 0;
#endif
}
int GetNumAvailableLogicalCPUCores() { return port::NumSchedulableCPUs(); }
} // end namespace grappler

View File

@ -29,6 +29,10 @@ namespace grappler {
// than 8.
int GetNumAvailableGPUs();
// Maximum amount of gpu memory available per gpu. gpu_id must be in the range
// [0, num_available_gpu)
int64 AvailableGPUMemory(int gpu_id);
// Get the number of logical CPU cores (aka hyperthreads) available.
int GetNumAvailableLogicalCPUCores();

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/grappler/inputs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@ -90,7 +91,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
node.clear_device();
}
if (node.op() == "Placeholder" || node.op() == "PlaceholderV2") {
if (IsPlaceholder(node)) {
if (node.attr().count("dtype") == 0) {
LOG(ERROR) << "Unknown type for placeholder " << node.name()
<< ", skipping this input";

View File

@ -0,0 +1,27 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/op_types.h"
namespace tensorflow {
namespace grappler {
bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op();
return op == "Placeholder" || op == "PlaceholderV2";
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,29 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_GRAPPLER_OP_TYPES_H_
#define TENSORFLOW_GRAPPLER_OP_TYPES_H_
#include "tensorflow/core/framework/node_def.pb.h"
namespace tensorflow {
namespace grappler {
bool IsPlaceholder(const NodeDef& node);
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_GRAPPLER_OP_TYPES_H_

View File

@ -25,6 +25,40 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "auto_parallel",
srcs = ["auto_parallel.cc"],
hdrs = [
"auto_parallel.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
],
)
cc_test(
name = "auto_parallel_test",
srcs = ["auto_parallel_test.cc"],
deps = [
":auto_parallel",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)
cc_library(
name = "constant_folding",
srcs = ["constant_folding.cc"],
@ -179,6 +213,7 @@ cc_library(
":constant_folding",
":graph_optimizer",
":layout_optimizer",
":memory_optimizer",
":model_pruner",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -0,0 +1,260 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace grappler {
const char kAutoParallelPrefix[] = "AutoParallel";
NodeDef* AutoParallel::AddNodeDivConst() {
NodeDef* node = graph_.add_node();
node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const"));
node->set_op("Const");
AttrValue attr_data_type;
attr_data_type.set_type(DT_FLOAT);
node->mutable_attr()->insert({"dtype", attr_data_type});
AttrValue attr_tensor;
auto tensor = attr_tensor.mutable_tensor();
tensor->add_float_val(static_cast<float>(num_replicas_));
tensor->set_dtype(DT_FLOAT);
node->mutable_attr()->insert({"value", attr_tensor});
return node;
}
NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a,
const string& input_b) {
NodeDef* node = graph_.add_node();
node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name));
node->set_op("RealDiv");
node->add_input(input_a);
node->add_input(input_b);
AttrValue attr_type;
attr_type.set_type(DT_FLOAT);
node->mutable_attr()->insert({"T", attr_type});
return node;
}
NodeDef* AutoParallel::AddNodeControl(const string& name,
const std::set<string>& deps,
GraphDef* graph) {
NodeDef* node = graph->add_node();
node->set_name(name);
node->set_op("NoOp");
for (const auto& dep : deps) {
node->add_input(strings::StrCat("^", dep));
}
return node;
}
Status AutoParallel::Initialize(const GrapplerItem& item) {
num_gpus_ = GetNumAvailableGPUs();
LOG(INFO) << "Number of GPUs: " << num_gpus_;
item_ = &item;
graph_ = item.graph;
LOG(INFO) << "Original graph size: " << graph_.node_size();
if (item.fetch.empty()) {
return Status(error::INVALID_ARGUMENT, "No fetch nodes provided.");
}
if (item.MainVariables().empty()) {
return Status(error::INVALID_ARGUMENT, "No variables provided.");
}
for (const auto& init : item.init_ops) {
VLOG(1) << "Init node: " << init;
}
for (const auto& fetch : item.fetch) {
VLOG(1) << "Fetch node: " << fetch;
}
for (const auto& var : item.MainVariables()) {
VLOG(2) << "Variable: " << var->name();
}
std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
"ApplyProximalGradientDescent",
"ApplyAdadelta",
"ApplyAdagrad",
"ApplyProximalAdagrad",
"ApplyAdagradDA",
"ApplyFtrl",
"ApplyMomentum",
"ApplyAdam",
"ApplyRMSProp",
"ApplyCenteredRMSProp"};
const NodeDef* dequeue_node = nullptr;
for (int i = 0; i < graph_.node_size(); i++) {
all_nodes_.insert(
std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
if (graph_.node(i).op() == "QueueDequeueManyV2") {
dequeue_node = graph_.mutable_node(i);
}
if (apply_gradients_ops.find(graph_.node(i).op()) !=
apply_gradients_ops.end()) {
apply_gradients_nodes_.insert(graph_.node(i).name());
VLOG(2) << "Apply gradients node: " << graph_.node(i).name();
}
}
auto div_const_node = AddNodeDivConst();
all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node));
std::map<string, int> gradient_pos = {{"ApplyGradientDescent", 2},
{"ApplyProximalGradientDescent", 4},
{"ApplyAdadelta", 6},
{"ApplyAdagrad", 3},
{"ApplyProximalAdagrad", 5},
{"ApplyAdagradDA", 3},
{"ApplyFtrl", 3},
{"ApplyMomentum", 3},
{"ApplyAdam", 9},
{"ApplyRMSProp", 7},
{"ApplyCenteredRMSProp", 8}};
for (const auto& apply_gradient_node_name : apply_gradients_nodes_) {
auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op();
auto apply_gradients_node = all_nodes_[apply_gradient_node_name];
auto div_node = AddNodeDiv(
apply_gradient_node_name,
apply_gradients_node->input(gradient_pos[apply_gradients_op]),
div_const_node->name());
all_nodes_.insert(std::make_pair(div_node->name(), div_node));
*apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) =
div_node->name();
}
LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();
auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch);
LOG(INFO) << "Number of training nodes: " << train_nodes.size();
std::vector<const NodeDef*> input_nodes;
if (dequeue_node) {
LOG(INFO) << "Dequeue node: " << dequeue_node->name();
input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()});
}
LOG(INFO) << "Number of input nodes: " << input_nodes.size();
std::set<string> dont_replicate_nodes;
for (const auto& variable : item.MainVariables()) {
dont_replicate_nodes.insert(variable->name());
}
// Don't replicate all input nodes, except the dequeue node.
for (const auto& input_node : input_nodes) {
if (input_node->name() != dequeue_node->name()) {
dont_replicate_nodes.insert(input_node->name());
}
}
for (const auto& node : train_nodes) {
if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) {
replica_nodes_.insert(node->name());
}
}
LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size();
for (const auto& node : all_nodes_) {
if (replica_nodes_.find(node.first) == replica_nodes_.end()) {
shared_nodes_.insert(node.first);
}
}
LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size();
return Status::OK();
}
bool AutoParallel::NotSharedNode(const string& name) {
return shared_nodes_.find(name) == shared_nodes_.end();
}
void AutoParallel::AddSharedNodes(GraphDef* graph) {
string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0);
for (const auto& node : shared_nodes_) {
auto new_node = graph->add_node();
*new_node = *all_nodes_[node];
for (int i = 0; i < new_node->input_size(); i++) {
if (NotSharedNode(NodeName(new_node->input(i)))) {
string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
*new_node->mutable_input(i) = new_name;
}
}
}
}
void AutoParallel::AddOneReplica(GraphDef* graph, int number) {
string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number);
for (const auto& node : replica_nodes_) {
auto new_node = graph->add_node();
*new_node = *all_nodes_[node];
if (NotSharedNode(new_node->name())) {
new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix));
if (num_gpus_ > 0) {
new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_));
}
for (int i = 0; i < new_node->input_size(); i++) {
if (NotSharedNode(NodeName(new_node->input(i)))) {
string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
*new_node->mutable_input(i) = new_name;
}
}
}
}
}
void AutoParallel::BuildGraph(GraphDef* graph) {
AddSharedNodes(graph);
for (int i = 0; i < num_replicas_; i++) {
AddOneReplica(graph, i);
}
std::set<string> fetches;
for (int i = 0; i < item_->fetch.size(); i++) {
for (int j = 0; j < num_replicas_; j++) {
string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
fetches.insert(fetch);
}
}
string name_control =
strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch");
auto control = AddNodeControl(name_control, fetches, graph);
for (const auto& fetch : item_->fetch) {
AddNodeControl(fetch, {control->name()}, graph);
}
LOG(INFO) << "Parallelized graph size: " << graph->node_size();
}
Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
TF_RETURN_IF_ERROR(Initialize(item));
BuildGraph(output);
return Status::OK();
}
void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) {
// TODO(yaozhang): Add feedback.
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
// Automatically parallelize a graph by splitting in the batch dimension.
class AutoParallel : public GraphOptimizer {
public:
AutoParallel(int num_replicas) : num_replicas_(num_replicas) {}
~AutoParallel() override {}
string name() const override { return "autoparallel"; };
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override;
private:
GraphDef graph_;
std::map<string, NodeDef*> all_nodes_;
std::set<string> apply_gradients_nodes_;
std::set<string> replica_nodes_;
std::set<string> shared_nodes_;
const GrapplerItem* item_;
int num_replicas_;
int num_gpus_;
Status Initialize(const GrapplerItem& item);
NodeDef* AddNodeDivConst();
NodeDef* AddNodeDiv(const string& name, const string& input_a,
const string& input_b);
NodeDef* AddNodeControl(const string& name, const std::set<string>& deps,
GraphDef* graph);
bool NotSharedNode(const string& name);
void AddSharedNodes(GraphDef* graph);
void AddOneReplica(GraphDef* graph, int number);
void BuildGraph(GraphDef* graph);
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_

View File

@ -0,0 +1,125 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
class AutoParallelTest : public ::testing::Test {};
TEST_F(AutoParallelTest, SimpleParallel) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT});
auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue},
{constant_b}, {DT_FLOAT});
Output add = ops::AddN(s.WithOpName("add"), {constant_a, dequeue[0]});
Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1});
Output apply_gradient = ops::ApplyGradientDescent(
s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add});
GrapplerItem item;
item.init_ops.push_back("assign");
item.fetch.push_back("apply_gradient");
TF_CHECK_OK(s.ToGraphDef(&item.graph));
AutoParallel parallel(2);
GraphDef output;
Status status = parallel.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(20, output.node_size());
const NodeDef& node_assign = output.node(0);
EXPECT_EQ("assign", node_assign.name());
EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_assign.input(1));
const NodeDef& node_constant_b = output.node(1);
EXPECT_EQ("constant_b", node_constant_b.name());
const NodeDef& node_fifo_queue = output.node(2);
EXPECT_EQ("fifo_queue", node_fifo_queue.name());
const NodeDef& node_var = output.node(3);
EXPECT_EQ("var", node_var.name());
const NodeDef& node_div_const0 = output.node(4);
EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-Const",
node_div_const0.name());
const NodeDef& node_div0 = output.node(5);
EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-apply_gradient",
node_div0.name());
const NodeDef& node_add0 = output.node(6);
EXPECT_EQ("AutoParallel-Replica-0-add", node_add0.name());
const NodeDef& node_gradient0 = output.node(7);
EXPECT_EQ("AutoParallel-Replica-0-apply_gradient", node_gradient0.name());
const NodeDef& node_constant_a0 = output.node(8);
EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_constant_a0.name());
const NodeDef& node_dequeue0 = output.node(9);
EXPECT_EQ("AutoParallel-Replica-0-dequeue", node_dequeue0.name());
const NodeDef& node_learning_rate0 = output.node(10);
EXPECT_EQ("AutoParallel-Replica-0-learning_rate", node_learning_rate0.name());
const NodeDef& node_div_const1 = output.node(11);
EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-Const",
node_div_const1.name());
const NodeDef& node_div1 = output.node(12);
EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-apply_gradient",
node_div1.name());
const NodeDef& node_add1 = output.node(13);
EXPECT_EQ("AutoParallel-Replica-1-add", node_add1.name());
const NodeDef& node_gradient1 = output.node(14);
EXPECT_EQ("AutoParallel-Replica-1-apply_gradient", node_gradient1.name());
const NodeDef& node_constant_a1 = output.node(15);
EXPECT_EQ("AutoParallel-Replica-1-constant_a", node_constant_a1.name());
const NodeDef& node_dequeue1 = output.node(16);
EXPECT_EQ("AutoParallel-Replica-1-dequeue", node_dequeue1.name());
const NodeDef& node_learning_rate1 = output.node(17);
EXPECT_EQ("AutoParallel-Replica-1-learning_rate", node_learning_rate1.name());
const NodeDef& node_fetch = output.node(18);
EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name());
EXPECT_EQ("^AutoParallel-Replica-0-apply_gradient", node_fetch.input(0));
EXPECT_EQ("^AutoParallel-Replica-1-apply_gradient", node_fetch.input(1));
const NodeDef& node_gradient = output.node(19);
EXPECT_EQ("apply_gradient", node_gradient.name());
EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/lib/core/status.h"
@ -37,6 +38,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
if (optimizer == "layout") {
graph_optimizer.reset(new LayoutOptimizer());
}
if (optimizer == "memory") {
graph_optimizer.reset(new MemoryOptimizer());
}
return graph_optimizer;
}
@ -55,8 +59,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
}
if (cfg_.memory_optimization() > 0) {
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new MemoryOptimizer()));
}
} else {
std::set<string> avaliable_optimizers = {"pruning", "constfold", "layout"};
std::set<string> avaliable_optimizers = {"pruning", "constfold", "layout",
"memory"};
for (const auto& optimizer : cfg_.optimizers()) {
if (avaliable_optimizers.find(optimizer) != avaliable_optimizers.end()) {
optimizers.push_back(NewOptimizer(optimizer));
@ -81,7 +90,6 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizer->Optimize(nullptr, optimized_item, optimized_graph));
}
}
// Copy the graph version.
*optimized_graph->mutable_versions() = item.graph.versions();

View File

@ -97,6 +97,10 @@ static void BatchToSpaceOpCompute(OpKernelContext* context,
for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
block_shape_product *= block_shape[block_dim];
}
OP_REQUIRES(
context, block_shape_product > 0,
errors::InvalidArgument("Product of block sizes must be positive, got ",
block_shape_product));
const int64 orig_input_batch_size = orig_input_tensor.dim_size(0);
OP_REQUIRES(

View File

@ -216,12 +216,14 @@ struct CropAndResize<CPUDevice, T> {
const float x_lerp = in_x - left_x_index;
for (int d = 0; d < depth; ++d) {
const float top_left(image(b_in, top_y_index, left_x_index, d));
const float top_right(image(b_in, top_y_index, right_x_index, d));
const float bottom_left(
image(b_in, bottom_y_index, left_x_index, d));
const float bottom_right(
image(b_in, bottom_y_index, right_x_index, d));
const float top_left(
static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
const float top_right(
static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
const float bottom_left(static_cast<float>(
image(b_in, bottom_y_index, left_x_index, d)));
const float bottom_right(static_cast<float>(
image(b_in, bottom_y_index, right_x_index, d)));
const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom =
bottom_left + (bottom_right - bottom_left) * x_lerp;
@ -545,12 +547,14 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float x_lerp = in_x - left_x_index;
for (int d = 0; d < depth; ++d) {
const float top_left(image(b_in, top_y_index, left_x_index, d));
const float top_right(image(b_in, top_y_index, right_x_index, d));
const float bottom_left(
image(b_in, bottom_y_index, left_x_index, d));
const float bottom_right(
image(b_in, bottom_y_index, right_x_index, d));
const float top_left(
static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
const float top_right(
static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
const float bottom_left(static_cast<float>(
image(b_in, bottom_y_index, left_x_index, d)));
const float bottom_right(static_cast<float>(
image(b_in, bottom_y_index, right_x_index, d)));
// Compute the image gradient.
float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
x_lerp * (bottom_right - top_right);
@ -606,18 +610,25 @@ inline void CheckValidBoxInd<CPUDevice>(
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("image_size"), \
CropAndResizeGradImageOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("image_size"), \
CropAndResizeGradImageOp<CPUDevice, T>);
TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
@ -685,7 +696,7 @@ inline void CheckValidBoxInd<GPUDevice>(
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<GPUDevice, T>);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

View File

@ -88,26 +88,26 @@ __global__ void CropAndResizeKernel(
const int right_x_index = ceilf(in_x);
const float x_lerp = in_x - left_x_index;
const float top_left(
const float top_left(static_cast<float>(
image_ptr[((b_in * image_height + top_y_index) * image_width +
left_x_index) *
depth +
d]);
const float top_right(
d]));
const float top_right(static_cast<float>(
image_ptr[((b_in * image_height + top_y_index) * image_width +
right_x_index) *
depth +
d]);
const float bottom_left(
d]));
const float bottom_left(static_cast<float>(
image_ptr[((b_in * image_height + bottom_y_index) * image_width +
left_x_index) *
depth +
d]);
const float bottom_right(
d]));
const float bottom_right(static_cast<float>(
image_ptr[((b_in * image_height + bottom_y_index) * image_width +
right_x_index) *
depth +
d]);
d]));
const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
crops_ptr[out_idx] = top + (bottom - top) * y_lerp;
@ -258,26 +258,26 @@ __global__ void CropAndResizeBackpropBoxesKernel(
const int right_x_index = ceilf(in_x);
const float x_lerp = in_x - left_x_index;
const float top_left =
const float top_left(static_cast<float>(
image_ptr[((b_in * image_height + top_y_index) * image_width +
left_x_index) *
depth +
d];
const float top_right =
d]));
const float top_right(static_cast<float>(
image_ptr[((b_in * image_height + top_y_index) * image_width +
right_x_index) *
depth +
d];
const float bottom_left =
d]));
const float bottom_left(static_cast<float>(
image_ptr[((b_in * image_height + bottom_y_index) * image_width +
left_x_index) *
depth +
d];
const float bottom_right =
d]));
const float bottom_right(static_cast<float>(
image_ptr[((b_in * image_height + bottom_y_index) * image_width +
right_x_index) *
depth +
d];
d]));
// Compute the image gradient.
float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
@ -436,7 +436,7 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
template struct CropAndResizeBackpropImage<GPUDevice, T>; \
template struct CropAndResizeBackpropBoxes<GPUDevice, T>;
TF_CALL_float(DEFINE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
@ -31,9 +32,10 @@ namespace tensorflow {
class CropAndResizeOpTest : public OpsTestBase {
protected:
template <typename T>
void MakeOp(float extrapolation_value) {
TF_EXPECT_OK(NodeDefBuilder("crop_and_resize_op", "CropAndResize")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DataTypeToEnum<T>::value))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_INT32))
@ -43,12 +45,33 @@ class CropAndResizeOpTest : public OpsTestBase {
}
};
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1) {
MakeOp(0);
#define REGISTER_TEST(T) \
TEST_F(CropAndResizeOpTest, TestCropAndResize##T) { \
MakeOp<T>(0); \
AddInputFromArray<T>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); \
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1}); \
AddInputFromArray<int32>(TensorShape({1}), {0}); \
AddInputFromArray<int32>(TensorShape({2}), {1, 1}); \
TF_ASSERT_OK(RunOpKernel()); \
\
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1})); \
test::FillValues<float>(&expected, {2.5}); \
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); \
}
REGISTER_TEST(float)
REGISTER_TEST(double)
REGISTER_TEST(int8)
REGISTER_TEST(uint8)
#undef REGISTER_TEST
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Uint8) {
MakeOp<uint8>(0);
// Input:
// 1, 2
// 3, 4
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<uint8>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({1}), {0});
AddInputFromArray<int32>(TensorShape({2}), {1, 1});
@ -60,7 +83,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Flipped) {
MakeOp(0);
MakeOp<float>(0);
// Input:
// 1, 2
// 3, 4
@ -76,7 +99,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Flipped) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3) {
MakeOp(0);
MakeOp<float>(0);
// Input:
// 1, 2
// 3, 4
@ -97,7 +120,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Flipped) {
MakeOp(0);
MakeOp<float>(0);
// Input:
// 1, 2
// 3, 4
@ -118,7 +141,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Flipped) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2) {
MakeOp(0);
MakeOp<float>(0);
// Input:
// 1, 2, 3
// 4, 5, 6
@ -143,7 +166,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2Flipped) {
MakeOp(0);
MakeOp<float>(0);
// Input:
// 1, 2, 3
// 4, 5, 6
@ -169,7 +192,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2Flipped) {
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
const float v = -1;
MakeOp(v);
MakeOp<float>(v);
// Input:
// 1, 2
// 3, 4
@ -190,7 +213,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
MakeOp(0);
MakeOp<float>(0);
// Input:
// 1, 2
// 3, 4
@ -208,7 +231,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
}
TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
MakeOp(0);
MakeOp<float>(0);
AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({1}), {0});
@ -220,7 +243,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
MakeOp(0);
MakeOp<float>(0);
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
@ -233,7 +256,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
MakeOp(0);
MakeOp<float>(0);
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({1}), {1});

View File

@ -38,6 +38,7 @@ namespace functor {
DECLARE_GPU_SPECS_INDEX(T, int64)
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX

View File

@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice;
DEFINE_GPU_SPECS_INDEX(T, int64);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
TF_CALL_complex64(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX

View File

@ -114,6 +114,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
TF_CALL_complex64(REGISTER_GATHER_GPU);
#undef REGISTER_GATHER_GPU

View File

@ -40,9 +40,9 @@ namespace {
class GatherOpTest : public OpsTestBase {
protected:
void MakeOp(DataType index_type) {
void MakeOp(DataType data_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "Gather")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(data_type))
.Input(FakeInput(index_type))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
@ -50,7 +50,7 @@ class GatherOpTest : public OpsTestBase {
};
TEST_F(GatherOpTest, ScalarIndices) {
MakeOp(DT_INT32);
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4});
@ -63,8 +63,26 @@ TEST_F(GatherOpTest, ScalarIndices) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(GatherOpTest, ScalarIndices_Complex) {
MakeOp(DT_COMPLEX64, DT_INT32);
// Feed and run
AddInputFromArray<std::complex<float>>(
TensorShape({5}), {std::complex<float>(0, 10), std::complex<float>(1, 11),
std::complex<float>(2, 12), std::complex<float>(3, 13),
std::complex<float>(4, 14)});
AddInputFromArray<int32>(TensorShape({}), {3});
TF_ASSERT_OK(RunOpKernel());
// Check the output.
Tensor expected(allocator(), DT_COMPLEX64, TensorShape({}));
test::FillValues<std::complex<float>>(&expected,
{std::complex<float>(3, 13)});
test::ExpectTensorEqual<std::complex<float>>(expected, *GetOutput(0));
}
TEST_F(GatherOpTest, Simple_TwoD32) {
MakeOp(DT_INT32);
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
@ -79,7 +97,7 @@ TEST_F(GatherOpTest, Simple_TwoD32) {
}
TEST_F(GatherOpTest, ZeroSize_TwoD32) {
MakeOp(DT_INT32);
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 0}), {});
@ -92,7 +110,7 @@ TEST_F(GatherOpTest, ZeroSize_TwoD32) {
}
TEST_F(GatherOpTest, Simple_TwoD64) {
MakeOp(DT_INT64);
MakeOp(DT_FLOAT, DT_INT64);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
@ -107,7 +125,7 @@ TEST_F(GatherOpTest, Simple_TwoD64) {
}
TEST_F(GatherOpTest, HighRank) {
MakeOp(DT_INT32);
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({4}), {0, 1, 2, 3});
@ -121,7 +139,7 @@ TEST_F(GatherOpTest, HighRank) {
}
TEST_F(GatherOpTest, Error_IndexOutOfRange) {
MakeOp(DT_INT32);
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),

View File

@ -295,6 +295,83 @@ static void RunFusedGraph(const GraphDef& fused_graph_def) {
reinterpret_cast<const float*>(output_tensor.flat<float>().data()));
}
static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
const GraphTransferInfo& gfi1) {
LOG(INFO) << "(1) node count: " << gfi1.node_info_size() << ", "
<< gfi1.const_node_info_size();
// 1. check node_info
ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
for (int i = 0; i < gfi0.node_info_size(); ++i) {
const GraphTransferInfo::NodeInfo& ni0 = gfi0.node_info(i);
const GraphTransferInfo::NodeInfo& ni1 = gfi1.node_info(i);
EXPECT_EQ(ni0.DebugString(), ni1.DebugString());
EXPECT_EQ(ni0.ByteSize(), ni1.ByteSize());
}
// 2. check const_node_info
ASSERT_EQ(gfi0.const_node_info_size(), gfi1.const_node_info_size());
for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
const GraphTransferInfo::ConstNodeInfo& cni0 = gfi0.const_node_info(i);
const GraphTransferInfo::ConstNodeInfo& cni1 = gfi1.const_node_info(i);
ASSERT_EQ(cni0.shape_size(), cni1.shape_size());
for (int j = 0; j < cni0.shape_size(); ++j) {
EXPECT_EQ(cni0.shape(j), cni1.shape(j));
}
EXPECT_EQ(cni0.ByteSize(), cni1.ByteSize());
EXPECT_EQ(cni0.DebugString(), cni1.DebugString());
}
// 3. check node_input_info
ASSERT_EQ(gfi0.node_input_info_size(), gfi1.node_input_info_size());
for (int i = 0; i < gfi0.node_input_info_size(); ++i) {
const GraphTransferInfo::NodeInputInfo& nii0 = gfi0.node_input_info(i);
const GraphTransferInfo::NodeInputInfo& nii1 = gfi1.node_input_info(i);
EXPECT_EQ(nii0.ByteSize(), nii1.ByteSize());
EXPECT_EQ(nii0.DebugString(), nii1.DebugString());
}
// 4. check node_output_info
ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
const GraphTransferInfo::NodeOutputInfo& noi0 = gfi0.node_output_info(i);
const GraphTransferInfo::NodeOutputInfo& noi1 = gfi1.node_output_info(i);
ASSERT_EQ(noi0.max_byte_size_size(), noi1.max_byte_size_size());
for (int j = 0; j < noi0.max_byte_size_size(); ++j) {
EXPECT_EQ(noi0.max_byte_size(j), noi1.max_byte_size(j));
}
EXPECT_EQ(noi0.ByteSize(), noi1.ByteSize());
EXPECT_EQ(noi0.DebugString(), noi1.DebugString());
}
// 5. check graph_input_node_info
ASSERT_EQ(gfi0.graph_input_node_info_size(),
gfi1.graph_input_node_info_size());
for (int i = 0; i < gfi0.graph_input_node_info_size(); ++i) {
const GraphTransferInfo::GraphInputNodeInfo& gini0 =
gfi0.graph_input_node_info(i);
const GraphTransferInfo::GraphInputNodeInfo& gini1 =
gfi0.graph_input_node_info(i);
EXPECT_EQ(gini0.ByteSize(), gini1.ByteSize());
EXPECT_EQ(gini0.DebugString(), gini1.DebugString());
}
// 6. check graph_output_node_info
ASSERT_EQ(gfi0.graph_output_node_info_size(),
gfi1.graph_output_node_info_size());
for (int i = 0; i < gfi0.graph_output_node_info_size(); ++i) {
const GraphTransferInfo::GraphOutputNodeInfo& goni0 =
gfi0.graph_output_node_info(i);
const GraphTransferInfo::GraphOutputNodeInfo& goni1 =
gfi0.graph_output_node_info(i);
EXPECT_EQ(goni0.ByteSize(), goni1.ByteSize());
EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
}
// 7. check destination
EXPECT_EQ(gfi0.destination(), gfi1.destination());
}
// CAVEAT: This test only runs when you specify hexagon library using
// makefile.
// CAVEAT: This test is disabled by default because hexagon can keep only
@ -450,34 +527,22 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
prof1.Stop();
prof1.DumpStatistics("Estiame shape by shape inference");
LOG(INFO) << "(1) node count: " << gfi1.node_info_size() << ", "
<< gfi1.const_node_info_size();
CompareGraphTransferInfo(gfi0, gfi1);
ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
const RemoteFusedGraphExecuteInfo ei0 =
BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi0);
const RemoteFusedGraphExecuteInfo ei1 =
BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi1);
ASSERT_EQ(gt0.GetGraphTransferInfo().const_node_info_size(),
gt1.GetGraphTransferInfo().const_node_info_size());
GraphTransferInfo rgfi0;
rgfi0.ParseFromString(ei0.serialized_executor_parameters());
GraphTransferInfo rgfi1;
rgfi1.ParseFromString(ei1.serialized_executor_parameters());
for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
const GraphTransferInfo::ConstNodeInfo& ni0 = gfi0.const_node_info(i);
const GraphTransferInfo::ConstNodeInfo& ni1 = gfi1.const_node_info(i);
ASSERT_EQ(ni0.shape_size(), ni1.shape_size());
for (int j = 0; j < ni0.shape_size(); ++j) {
EXPECT_EQ(ni0.shape(j), ni1.shape(j));
}
}
ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
const GraphTransferInfo::NodeOutputInfo& no0 = gfi0.node_output_info(i);
const GraphTransferInfo::NodeOutputInfo& no1 = gfi1.node_output_info(i);
ASSERT_EQ(no0.max_byte_size_size(), no1.max_byte_size_size());
for (int j = 0; j < no0.max_byte_size_size(); ++j) {
EXPECT_EQ(no0.max_byte_size(j), no1.max_byte_size(j));
}
}
CompareGraphTransferInfo(rgfi0, rgfi1);
CompareGraphTransferInfo(gfi0, rgfi0);
CompareGraphTransferInfo(gfi1, rgfi1);
}
#endif
} // namespace tensorflow

View File

@ -174,6 +174,7 @@ const std::unordered_map<string, SupportedOpType> OP_NAME_TO_SOC_OP_TYPE_MAP{
{"Placeholder", SupportedOpType::NOP},
{"RequantizationRange", SupportedOpType::REQUANTIZATION_RANGE_32},
{"Requantize", SupportedOpType::REQUANTIZE_32_TO_8},
{"QuantizedReshape", SupportedOpType::QUANTIZED_RESHAPE},
};
/* static */ const IGraphTransferOpsDefinitions&

View File

@ -587,8 +587,8 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{2}, 0, tensor_out.shape(), &output));
OP_REQUIRES_OK(context,
context->allocate_output(0, tensor_out.shape(), &output));
PoolParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};

View File

@ -70,7 +70,7 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
dtype maxval = -FLT_MAX;
dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * channels * height * width;
for (int h = hstart; h < hend; ++h) {
@ -312,9 +312,6 @@ __global__ void MaxPoolGradBackwardNoMaskNHWC(
// bottom_offset: the pre-computed per-image offset of the maxpool output.
// This is equal to Hout*Wout*C.
// bottom_diff: the gradient of the gradient w.r.t. output.
// This function relies on CudaAtomicAdd to avoid race conditions. Also, before
// the kernel is run, you will need to make sure that bottom_diff is filled with
// zero first.
template <typename dtype>
__global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
const int64* mask, const int top_offset,
@ -357,12 +354,12 @@ bool MaxPoolBackwardNoMask<T>::operator()(
const int stride_w, const int pad_t, const int pad_l, const T* top_diff,
T* bottom_diff, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
const int bottom_size = batch * channels * height * width;
const int top_size = batch * channels * pooled_height * pooled_width;
const int bottom_size = batch * channels * height * width;
SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
const int top_size = batch * channels * pooled_height * pooled_width;
MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(

Some files were not shown because too many files have changed in this diff Show More